Skip to content

ClusteringMetrics API

ClusteringMetrics

A collection of clustering and classification metrics, both unsupervised and supervised.

This class provides methods such as:

  • KL divergence for comparing two GMMs (via Monte Carlo).
  • Information criteria (AIC, BIC) for model selection.
  • Unsupervised metrics (silhouette, Davies-Bouldin, Calinski-Harabasz, Dunn index).
  • Supervised metrics (Rand index, ARI, mutual info variants, purity, classification report).

All methods are static for convenience, and most accept data in PyTorch Tensors (potentially on GPU). For supervised metrics, the user must provide labels_true and labels_pred as integer-encoded 1D tensors of the same shape.

Example

.. code-block:: python

from myproject.clustering_metrics import ClusteringMetrics

# Suppose gmm is a fitted GaussianMixture, gmm2 is another GMM
# Compare them via KL divergence:
kl_pq = ClusteringMetrics.kl_divergence_gmm(gmm, gmm2)
print("KL(p||q) =", kl_pq)

# Evaluate clustering performance w.r.t. true labels:
pred_labels = gmm.predict(X_tensor)
ari = ClusteringMetrics.adjusted_rand_score(labels_true, pred_labels)
print("ARI =", ari)
Source code in tgmm/metrics.py
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
class ClusteringMetrics:
    r"""
    A collection of clustering and classification metrics, both unsupervised and supervised.

    This class provides methods such as:

    - **KL divergence** for comparing two GMMs (via Monte Carlo).
    - **Information criteria** (AIC, BIC) for model selection.
    - **Unsupervised metrics** (silhouette, Davies-Bouldin, Calinski-Harabasz, Dunn index).
    - **Supervised metrics** (Rand index, ARI, mutual info variants, purity, classification report).

    All methods are static for convenience, and most accept data in PyTorch Tensors
    (potentially on GPU). For supervised metrics, the user must provide `labels_true`
    and `labels_pred` as integer-encoded 1D tensors of the same shape.

    Example
    -------
    .. code-block:: python

        from myproject.clustering_metrics import ClusteringMetrics

        # Suppose gmm is a fitted GaussianMixture, gmm2 is another GMM
        # Compare them via KL divergence:
        kl_pq = ClusteringMetrics.kl_divergence_gmm(gmm, gmm2)
        print("KL(p||q) =", kl_pq)

        # Evaluate clustering performance w.r.t. true labels:
        pred_labels = gmm.predict(X_tensor)
        ari = ClusteringMetrics.adjusted_rand_score(labels_true, pred_labels)
        print("ARI =", ari)
    """

    @staticmethod
    def bic_score(
        lower_bound_: float,
        X: torch.Tensor,
        n_components: int,
        covariance_type: str
    ) -> float:
        r"""
        Compute the Bayesian Information Criterion (BIC) for a GMM given its
        average log-likelihood (lower bound).

        $ \text{BIC} = n_{\text{params}} \ln(n_{\text{samples}}) - 2 \times \text{log\_likelihood} $

        Parameters
        ----------
        lower_bound_ : float
            Average (per-sample) log-likelihood or lower bound from the GMM.
        X : torch.Tensor
            Data used in fitting, shape (n_samples, n_features).
        n_components : int
            Number of mixture components in the GMM.
        covariance_type : str
            Covariance type, one of {'full', 'tied', 'diag', 'spherical'}.

        Returns
        -------
        float
            The BIC score (lower is better).
        """
        n_samples, n_features = X.shape

        # Determine number of free parameters in covariance
        if covariance_type == 'full':
            cov_params = n_components * n_features * (n_features + 1) / 2.0
        elif covariance_type == 'diag':
            cov_params = n_components * n_features
        elif covariance_type == 'tied':
            cov_params = n_features * (n_features + 1) / 2.0
        elif covariance_type == 'spherical':
            cov_params = n_components
        else:
            raise ValueError(f"Unsupported covariance type: {covariance_type}")

        # Means + weights
        mean_params = n_features * n_components
        weight_params = n_components - 1

        n_parameters = cov_params + mean_params + weight_params
        log_likelihood = lower_bound_ * n_samples  # total log-likelihood

        bic = n_parameters * torch.log(torch.tensor(n_samples, dtype=torch.float)) - 2.0 * log_likelihood
        return bic.item()

    @staticmethod
    def aic_score(
        lower_bound_: float,
        X: torch.Tensor,
        n_components: int,
        covariance_type: str
    ) -> float:
        r"""
        Compute the Akaike Information Criterion (AIC) for a GMM.

        $ \text{AIC} = 2 \times n_{\text{params}} - 2 \times \text{log\_likelihood} $

        Parameters
        ----------
        lower_bound_ : float
            Average (per-sample) log-likelihood from the GMM.
        X : torch.Tensor
            Data used in fitting, shape (n_samples, n_features).
        n_components : int
            Number of mixture components.
        covariance_type : str
            Covariance type, one of {'full', 'tied', 'diag', 'spherical'}.

        Returns
        -------
        float
            The AIC score (lower is better).
        """
        n_samples, n_features = X.shape
        if covariance_type == 'full':
            cov_params = n_components * n_features * (n_features + 1) / 2.0
        elif covariance_type == 'diag':
            cov_params = n_components * n_features
        elif covariance_type == 'tied':
            cov_params = n_features * (n_features + 1) / 2.0
        elif covariance_type == 'spherical':
            cov_params = n_components
        else:
            raise ValueError(f"Unsupported covariance type: {covariance_type}")

        mean_params = n_features * n_components
        weight_params = n_components - 1

        n_parameters = cov_params + mean_params + weight_params
        log_likelihood = lower_bound_ * n_samples
        aic = 2.0 * n_parameters - 2.0 * log_likelihood

        return aic

    @staticmethod
    def silhouette_score(
        X: torch.Tensor,
        labels: torch.Tensor,
        n_components: int
    ) -> float:
        r"""
        Compute the silhouette score for a partition of the data.

        $$ \mathrm{silhouette}(i) = \frac{b_i - a_i}{\max(a_i, b_i)} $$

        Where:
        - $a_i$ is the mean distance to points in the same cluster.
        - $b_i$ is the minimum mean distance to points in a different cluster.

        Parameters
        ----------
        X : torch.Tensor
            Data, shape (n_samples, n_features).
        labels : torch.Tensor
            Cluster labels, shape (n_samples,).
        n_components : int
            Number of clusters (must be >= 2).

        Returns
        -------
        float
            The mean silhouette score over all samples (range is typically [-1, 1],
            though it’s seldom negative in practice if distances are Euclidean).
        """
        assert n_components > 1, "Silhouette score is only defined when there are at least 2 clusters."

        labels = labels.to(X.device)
        distances = torch.cdist(X, X)  # (n_samples, n_samples)

        A = torch.zeros(labels.size(0), dtype=torch.float, device=X.device)
        B = torch.full((labels.size(0),), float('inf'), dtype=torch.float, device=X.device)

        # For each cluster i, compute intra-cluster distances and inter-cluster distances
        for i in range(n_components):
            mask_i = (labels == i)
            if mask_i.sum() <= 1:
                continue

            intra_cluster_distances = distances[mask_i][:, mask_i]
            # a_i: average distance within the same cluster
            A[mask_i] = intra_cluster_distances.sum(dim=1) / (mask_i.sum() - 1)

            # b_i: minimum distance to any other cluster
            for j in range(n_components):
                if i == j:
                    continue
                mask_j = (labels == j)
                if mask_j.sum() == 0:
                    continue
                inter_cluster_distances = distances[mask_i][:, mask_j]
                B[mask_i] = torch.min(B[mask_i], inter_cluster_distances.mean(dim=1))

        silhouette_scores = (B - A) / torch.max(A, B)
        return silhouette_scores.mean().item()

    @staticmethod
    def davies_bouldin_index(X: torch.Tensor, labels: torch.Tensor, n_components: int) -> float:
        r"""
        Compute the Davies-Bouldin index (lower is better).

        $$ \text{DB} = \frac{1}{k} \sum_i \max_{j \neq i} \frac{S_i + S_j}{M_{ij}} $$

        Where:
        - $S_i$ is the average distance of points in cluster i to its centroid.
        - $M_{ij}$ is the distance between cluster centroids i and j.

        Parameters
        ----------
        X : torch.Tensor
            Data, shape (n_samples, n_features).
        labels : torch.Tensor
            Cluster labels, shape (n_samples,).
        n_components : int
            Number of clusters (must be >= 2).

        Returns
        -------
        float
            Davies-Bouldin index.
        """
        assert n_components > 1, "Davies-Bouldin index is only defined when >= 2 clusters."
        labels = labels.to(X.device)

        # Compute cluster centroids
        centroids = [X[labels == i].mean(dim=0) for i in range(n_components)]
        centroids = torch.stack(centroids, dim=0)

        # Distance matrix among centroids
        cluster_distances = torch.cdist(centroids, centroids)

        similarities = torch.zeros((n_components, n_components), device=X.device)

        for i in range(n_components):
            mask_i = (labels == i)
            dist_i = torch.norm(X[mask_i] - centroids[i], dim=1).mean()

            for j in range(n_components):
                if i == j:
                    continue
                mask_j = (labels == j)
                dist_j = torch.norm(X[mask_j] - centroids[j], dim=1).mean()
                similarities[i, j] = (dist_i + dist_j) / cluster_distances[i, j]

        db_index = torch.max(similarities, dim=1).values.mean()
        return db_index.item()

    @staticmethod
    def calinski_harabasz_score(X: torch.Tensor, labels: torch.Tensor, n_components: int) -> float:
        r"""
        Compute the Calinski-Harabasz index (a ratio of between-cluster dispersion
        to within-cluster dispersion).

        Parameters
        ----------
        X : torch.Tensor
            Data, shape (n_samples, n_features).
        labels : torch.Tensor
            Cluster labels.
        n_components : int
            Number of clusters.

        Returns
        -------
        float
            Calinski-Harabasz index (higher is better).
        """
        labels = labels.to(X.device)
        centroid_overall = X.mean(dim=0)

        # Cluster centroids
        centroids = [X[labels == i].mean(dim=0) for i in range(n_components)]
        centroids = torch.stack(centroids)

        # Between-cluster dispersion (SSB) & within-cluster (SSW)
        SSB = sum((labels == i).sum() * torch.norm(centroids[i] - centroid_overall).pow(2)
                  for i in range(n_components))
        SSW = sum(torch.norm(X[labels == i] - centroids[i], dim=1).pow(2).sum()
                  for i in range(n_components))

        n_samples = X.shape[0]
        CH = (SSB / (n_components - 1)) / (SSW / (n_samples - n_components))
        return CH.item()

    @staticmethod
    def dunn_index(X: torch.Tensor, labels: torch.Tensor, n_components: int) -> float:
        r"""
        Compute the Dunn index:

        $$ D = \frac{\min_{i \neq j} d(C_i, C_j)}{\max_k \mathrm{diam}(C_k)} $$

        Where:
        - $d(C_i, C_j)$ is the minimum distance between any points in clusters i, j.
        - $\mathrm{diam}(C_k)$ is the maximum distance between any points in cluster k.

        Higher Dunn index indicates better cluster separation.

        Parameters
        ----------
        X : torch.Tensor
            Data, shape (n_samples, n_features).
        labels : torch.Tensor
            Cluster labels.
        n_components : int
            Number of clusters.

        Returns
        -------
        float
            Dunn index (higher is better).
        """
        labels = labels.to(X.device)
        distances = torch.cdist(X, X)

        min_intercluster_dist = float('inf')
        max_intracluster_dist = 0.0

        for i in range(n_components):
            mask_i = (labels == i)
            if mask_i.sum() <= 1:
                continue

            intra_distances = distances[mask_i][:, mask_i]
            max_intracluster_dist = max(max_intracluster_dist, intra_distances.max().item())

            for j in range(i + 1, n_components):
                mask_j = (labels == j)
                if mask_j.sum() == 0:
                    continue
                inter_distances = distances[mask_i][:, mask_j]
                current_min = inter_distances.min().item()
                if current_min < min_intercluster_dist:
                    min_intercluster_dist = current_min

        dunn_index = (min_intercluster_dist / max_intracluster_dist) if max_intracluster_dist > 0 else 0.0
        return dunn_index

    # --------------------------
    # Supervised Metrics
    # --------------------------
    @staticmethod
    def rand_score(labels_true: torch.Tensor, labels_pred: torch.Tensor) -> float:
        r"""
        Rand Index (RI) measures the similarity between two clusterings.
        It counts the agreement of pairwise assignments.

        $ \text{RI} = \frac{\text{TP} + \text{TN}}{\text{TP} + \text{TN} + \text{FP} + \text{FN}} $

        Parameters
        ----------
        labels_true : torch.Tensor
            Ground-truth labels, shape (n_samples,).
        labels_pred : torch.Tensor
            Predicted labels, shape (n_samples,).

        Returns
        -------
        float
            Rand Index in [0, 1].
        """
        device = labels_true.device
        n_samples = labels_true.size(0)

        # Build contingency matrix
        contingency = torch.zeros(
            (labels_true.max() + 1, labels_pred.max() + 1),
            dtype=torch.float, device=device
        )
        for i in range(n_samples):
            contingency[labels_true[i], labels_pred[i]] += 1

        sum_comb_c = torch.sum(contingency.pow(2) - contingency) / 2
        sum_comb = torch.sum(contingency.sum(dim=1).pow(2) - contingency.sum(dim=1)) / 2
        sum_comb_pred = torch.sum(contingency.sum(dim=0).pow(2) - contingency.sum(dim=0)) / 2

        tp = sum_comb_c
        fp = sum_comb_pred - tp
        fn = sum_comb - tp
        tn = n_samples * (n_samples - 1) / 2 - tp - fp - fn

        ri = (tp + tn) / (tp + fp + fn + tn)
        return ri.item()

    @staticmethod
    def adjusted_rand_score(labels_true: torch.Tensor, labels_pred: torch.Tensor) -> float:
        r"""
        Adjusted Rand Index (ARI), which adjusts RI for chance.

        $ \mathrm{ARI} = \frac{ \mathrm{RI} - \mathrm{E}[\mathrm{RI}] }{ \max(\mathrm{RI}) - \mathrm{E}[\mathrm{RI}] } $

        Parameters
        ----------
        labels_true : torch.Tensor
            Ground-truth labels.
        labels_pred : torch.Tensor
            Predicted labels.

        Returns
        -------
        float
            ARI in [-1, 1], though it’s typically in [0, 1].
        """
        device = labels_true.device
        n_samples = labels_true.size(0)

        # Build contingency matrix
        contingency = torch.zeros(
            (labels_true.max() + 1, labels_pred.max() + 1),
            dtype=torch.float, device=device
        )
        for i in range(n_samples):
            contingency[labels_true[i], labels_pred[i]] += 1

        sum_comb_c = torch.sum(contingency.pow(2) - contingency) / 2
        sum_comb = torch.sum(contingency.sum(dim=1).pow(2) - contingency.sum(dim=1)) / 2
        sum_comb_pred = torch.sum(contingency.sum(dim=0).pow(2) - contingency.sum(dim=0)) / 2

        expected_index = sum_comb * sum_comb_pred / (n_samples * (n_samples - 1) / 2)
        max_index = (sum_comb + sum_comb_pred) / 2
        rand_index = sum_comb_c

        ari = (rand_index - expected_index) / (max_index - expected_index)
        return ari.item()

    @staticmethod
    def mutual_info_score(labels_true: torch.Tensor, labels_pred: torch.Tensor) -> float:
        r"""
        Mutual Information (MI) between two clusterings.

        $ \mathrm{MI}(U, V) = \sum_{u \in U}\sum_{v \in V} p(u, v) \log\frac{p(u,v)}{p(u)p(v)} $

        Parameters
        ----------
        labels_true : torch.Tensor
            Ground-truth labels.
        labels_pred : torch.Tensor
            Predicted labels.

        Returns
        -------
        float
            Mutual information (>= 0).
        """
        device = labels_true.device
        contingency = torch.zeros(
            (labels_true.max() + 1, labels_pred.max() + 1),
            dtype=torch.float, device=device
        )
        for i in range(labels_true.size(0)):
            contingency[labels_true[i], labels_pred[i]] += 1

        contingency /= contingency.sum()
        outer = contingency.sum(dim=1, keepdim=True) * contingency.sum(dim=0, keepdim=True)
        nonzero = contingency > 0
        mi = (contingency[nonzero] *
              (torch.log(contingency[nonzero]) - torch.log(outer[nonzero]))).sum()
        return mi.item()

    @staticmethod
    def adjusted_mutual_info_score(labels_true: torch.Tensor, labels_pred: torch.Tensor) -> float:
        r"""
        Adjusted Mutual Information (AMI).

        Parameters
        ----------
        labels_true : torch.Tensor
            Ground-truth labels.
        labels_pred : torch.Tensor
            Predicted labels.

        Returns
        -------
        float
            AMI in [0, 1].
        """
        mi = ClusteringMetrics.mutual_info_score(labels_true, labels_pred)
        n_samples = labels_true.size(0)
        true_counts = torch.bincount(labels_true)
        pred_counts = torch.bincount(labels_pred)

        h_true = -torch.sum((true_counts / n_samples) *
                            torch.log(true_counts / n_samples + 1e-10))
        h_pred = -torch.sum((pred_counts / n_samples) *
                            torch.log(pred_counts / n_samples + 1e-10))

        # This "expected_mi" here is a rough approximation
        expected_mi = (h_true * h_pred) / n_samples
        ami = (mi - expected_mi) / (0.5 * (h_true + h_pred) - expected_mi)
        return ami.item()

    @staticmethod
    def normalized_mutual_info_score(labels_true: torch.Tensor, labels_pred: torch.Tensor) -> float:
        r"""
        Normalized Mutual Information (NMI) in [0, 1].

        $ \text{NMI} = \frac{2 \times \mathrm{MI}}{H(\text{true}) + H(\text{pred})} $

        Parameters
        ----------
        labels_true : torch.Tensor
            Ground-truth labels.
        labels_pred : torch.Tensor
            Predicted labels.

        Returns
        -------
        float
            NMI.
        """
        mi = ClusteringMetrics.mutual_info_score(labels_true, labels_pred)
        n_samples = labels_true.size(0)
        true_counts = torch.bincount(labels_true).float()
        pred_counts = torch.bincount(labels_pred).float()

        h_true = -torch.sum((true_counts / n_samples) *
                            torch.log(true_counts / n_samples + 1e-10))
        h_pred = -torch.sum((pred_counts / n_samples) *
                            torch.log(pred_counts / n_samples + 1e-10))

        nmi = 2.0 * mi / (h_true + h_pred)
        return nmi.item()

    @staticmethod
    def fowlkes_mallows_score(labels_true: torch.Tensor, labels_pred: torch.Tensor) -> float:
        r"""
        Fowlkes-Mallows index (FM) = $ \sqrt{ \text{precision} \times \text{recall} } $.

        Parameters
        ----------
        labels_true : torch.Tensor
            Ground-truth labels.
        labels_pred : torch.Tensor
            Predicted labels.

        Returns
        -------
        float
            FM in [0, 1].
        """
        device = labels_true.device
        n_samples = labels_true.size(0)
        contingency = torch.zeros(
            (labels_true.max() + 1, labels_pred.max() + 1),
            dtype=torch.float, device=device
        )
        for i in range(n_samples):
            contingency[labels_true[i], labels_pred[i]] += 1

        tp = torch.sum(contingency.pow(2)) - n_samples
        tp_fp = torch.sum(contingency.sum(dim=0).pow(2)) - n_samples
        tp_fn = torch.sum(contingency.sum(dim=1).pow(2)) - n_samples

        fm = torch.sqrt(tp / tp_fp * tp / tp_fn)
        return fm.item()

    @staticmethod
    def completeness_score(labels_true: torch.Tensor, labels_pred: torch.Tensor) -> float:
        r"""
        Completeness score measures how much each cluster contains only samples
        of a single class.

        Parameters
        ----------
        labels_true : torch.Tensor
            Ground-truth labels.
        labels_pred : torch.Tensor
            Predicted labels.

        Returns
        -------
        float
            Completeness in [0, 1].
        """
        device = labels_true.device
        n_samples = labels_true.size(0)
        contingency = torch.zeros(
            (labels_true.max() + 1, labels_pred.max() + 1),
            dtype=torch.float, device=device
        )
        for i in range(n_samples):
            contingency[labels_true[i], labels_pred[i]] += 1

        entropy_true = -torch.sum(labels_true.bincount().float() / n_samples *
                                  torch.log(labels_true.bincount().float() / n_samples + 1e-10))

        # H(C | K) (conditional entropy)
        entropy_cond = -torch.sum(contingency / n_samples *
                                  torch.log((contingency + 1e-10) /
                                            contingency.sum(dim=1, keepdim=True)))

        comp_score = 1 - entropy_cond / entropy_true
        return comp_score.item()

    @staticmethod
    def homogeneity_score(labels_true: torch.Tensor, labels_pred: torch.Tensor) -> float:
        r"""
        Homogeneity score measures if each cluster contains samples of only one class.

        Parameters
        ----------
        labels_true : torch.Tensor
        labels_pred : torch.Tensor

        Returns
        -------
        float
            Homogeneity in [0, 1].
        """
        device = labels_true.device
        n_samples = labels_true.size(0)
        contingency = torch.zeros(
            (labels_true.max() + 1, labels_pred.max() + 1),
            dtype=torch.float, device=device
        )
        for i in range(n_samples):
            contingency[labels_true[i], labels_pred[i]] += 1

        entropy_pred = -torch.sum(labels_pred.bincount().float() / n_samples *
                                  torch.log(labels_pred.bincount().float() / n_samples + 1e-10))

        # H(K | C)
        entropy_cond = -torch.sum(contingency / n_samples *
                                  torch.log((contingency + 1e-10) /
                                            contingency.sum(dim=0, keepdim=True)))

        hom_score = 1 - entropy_cond / entropy_pred
        return hom_score.item()

    @staticmethod
    def v_measure_score(labels_true: torch.Tensor, labels_pred: torch.Tensor) -> float:
        r"""
        V-measure = $ \frac{2 \times (\text{homogeneity} \times \text{completeness})}
                           {\text{homogeneity} + \text{completeness}} $.

        Parameters
        ----------
        labels_true : torch.Tensor
        labels_pred : torch.Tensor

        Returns
        -------
        float
            V-measure in [0, 1].
        """
        homogeneity = ClusteringMetrics.homogeneity_score(labels_true, labels_pred)
        completeness = ClusteringMetrics.completeness_score(labels_true, labels_pred)
        if (homogeneity + completeness) == 0:
            return 0.0
        return 2.0 * (homogeneity * completeness) / (homogeneity + completeness)

    @staticmethod
    def purity_score(labels_true: torch.Tensor, labels_pred: torch.Tensor) -> float:
        r"""
        Purity score measures how many samples belong to the correct cluster.

        $ \mathrm{purity} = \frac{1}{N} \sum_k \max_j \lvert C_k \cap U_j \rvert $

        Parameters
        ----------
        labels_true : torch.Tensor
        labels_pred : torch.Tensor

        Returns
        -------
        float
            Purity in [0, 1].
        """
        device = labels_true.device
        n_samples = labels_true.size(0)
        contingency = torch.zeros(
            (labels_true.max() + 1, labels_pred.max() + 1),
            dtype=torch.float, device=device
        )
        for i in range(n_samples):
            contingency[labels_true[i], labels_pred[i]] += 1

        purity = torch.sum(torch.max(contingency, dim=0).values) / n_samples
        return purity.item()

    @staticmethod
    def confusion_matrix(labels_true: torch.Tensor, labels_pred: torch.Tensor) -> torch.Tensor:
        r"""
        Compute a confusion matrix.

        Parameters
        ----------
        labels_true : torch.Tensor
        labels_pred : torch.Tensor

        Returns
        -------
        torch.Tensor
            A 2D (C x C) matrix, where C is the number of unique labels.
        """
        unique_labels = torch.unique(labels_true)
        num_labels = unique_labels.size(0)
        cm = torch.zeros((num_labels, num_labels), dtype=torch.int32)

        for i, label_true in enumerate(unique_labels):
            for j, label_pred in enumerate(unique_labels):
                cm[i, j] = ((labels_true == label_true) & (labels_pred == label_pred)).sum().item()

        return cm

    @staticmethod
    def classification_report(labels_true: torch.Tensor, labels_pred: torch.Tensor) -> dict:
        r"""
        Compute a simple classification report for each class: precision, recall, F1,
        Jaccard index, ROC-AUC, and support.

        Parameters
        ----------
        labels_true : torch.Tensor
        labels_pred : torch.Tensor

        Returns
        -------
        dict
            A dictionary keyed by class label, each containing:
            "precision", "recall", "f1-score", "support", "jaccard", "roc_auc".
        """
        device = labels_true.device
        unique_labels = torch.unique(labels_true)
        report = {}

        for label in unique_labels:
            true_positives = ((labels_true == label) & (labels_pred == label)).sum().item()
            false_positives = ((labels_true != label) & (labels_pred == label)).sum().item()
            false_negatives = ((labels_true == label) & (labels_pred != label)).sum().item()

            precision = (true_positives / (true_positives + false_positives)
                         if (true_positives + false_positives) > 0 else 0.0)
            recall = (true_positives / (true_positives + false_negatives)
                      if (true_positives + false_negatives) > 0 else 0.0)
            f1_score = (2.0 * (precision * recall) / (precision + recall)
                        if (precision + recall) > 0 else 0.0)
            support = (labels_true == label).sum().item()
            jaccard_index = (true_positives / (true_positives + false_positives + false_negatives)
                             if (true_positives + false_positives + false_negatives) > 0 else 0.0)

            binary_true = (labels_true == label).float().to(device)
            binary_pred = (labels_pred == label).float().to(device)
            roc_auc = ClusteringMetrics.roc_auc_score(binary_true, binary_pred)

            report[int(label)] = {
                "precision": precision,
                "recall": recall,
                "f1-score": f1_score,
                "support": support,
                "jaccard": jaccard_index,
                "roc_auc": roc_auc
            }

        return report

    @staticmethod
    def roc_auc_score(labels_true: torch.Tensor, labels_pred: torch.Tensor) -> float:
        r"""
        Compute a naive ROC-AUC for binary predictions (0 or 1).
        If all labels are the same class, returns 1.0 by definition.

        Parameters
        ----------
        labels_true : torch.Tensor
            Binary ground-truth (0 or 1), shape (n_samples,).
        labels_pred : torch.Tensor
            Binary predictions or real-valued probabilities, shape (n_samples,).

        Returns
        -------
        float
            Area Under the ROC Curve (AUC).
        """
        if labels_true.sum() == 0 or labels_true.sum() == labels_true.size(0):
            # Degenerate case: all positives or all negatives => AUC = 1.0 or undefined
            return 1.0

        sorted_indices = torch.argsort(labels_pred, descending=True)
        labels_true = labels_true[sorted_indices]

        tpr = torch.cumsum(labels_true, dim=0) / labels_true.sum()
        fpr = torch.cumsum(1 - labels_true, dim=0) / (labels_true.size(0) - labels_true.sum())

        auc = torch.trapz(tpr, fpr)
        return auc.item()

    @staticmethod
    def kl_divergence_gmm(gmm_p, gmm_q, n_samples: int = 10000) -> float:
        r"""
        Approximate the KL divergence $D_{KL}(p \Vert q)$ between two
        Gaussian Mixture Models using Monte Carlo sampling from ``gmm_p``.

        Parameters
        ----------
        gmm_p : GaussianMixture
            The first GMM (interpreted as distribution p).
        gmm_q : GaussianMixture
            The second GMM (interpreted as distribution q).
        n_samples : int, optional
            Number of samples to draw from gmm_p for the Monte Carlo approximation
            (default: 10000).

        Returns
        -------
        float
            Approximated KL divergence $ \mathrm{E}_{x \sim p}[\log p(x) - \log q(x)]$.
        """
        device = gmm_p.device
        samples, _ = gmm_p.sample(n_samples)  # (n_samples, n_features)
        samples = samples.to(device)

        # Log-likelihood of samples under both GMMs
        log_p = gmm_p.score_samples(samples)
        log_q = gmm_q.score_samples(samples)

        kl_divergence = (log_p - log_q).mean().item()
        return kl_divergence

    @staticmethod
    def evaluate_clustering(
        gmm_model,
        X: torch.Tensor,
        true_labels: Optional[torch.Tensor] = None,
        metrics: Optional[list] = None
    ) -> dict:
        r"""
        Evaluate a fitted GMM using a set of requested metrics, possibly with ground-truth labels.

        Parameters
        ----------
        gmm_model : GaussianMixture
            A fitted GMM model (must have ``fitted_ == True``).
        X : torch.Tensor
            Data to evaluate, shape (n_samples, n_features).
        true_labels : torch.Tensor or None
            Ground-truth labels for supervised metrics (optional).
        metrics : list of str or None
            Which metrics to compute. If None, uses a default set.

        Returns
        -------
        results : dict
            A dictionary of metric_name -> metric_value pairs.
        """
        if not gmm_model.fitted_:
            raise ValueError("The GMM model must be fitted before evaluation.")

        if metrics is None:
            metrics = [
                # Supervised
                "rand_score", "adjusted_rand_score", "mutual_info_score",
                "normalized_mutual_info_score", "adjusted_mutual_info_score",
                "fowlkes_mallows_score", "homogeneity_score", "completeness_score",
                "v_measure_score", "purity_score",
                # Classification-based
                "classification_report", "confusion_matrix",
                # Unsupervised
                "silhouette_score", "davies_bouldin_index", "calinski_harabasz_score",
                "dunn_index", "bic_score", "aic_score",
            ]

        # Predict cluster labels
        pred_labels = gmm_model.predict(X).cpu()
        results = {}

        # If ground-truth labels provided, compute supervised metrics
        if true_labels is not None:
            true_labels = true_labels.cpu()

            if "rand_score" in metrics:
                results["rand_score"] = ClusteringMetrics.rand_score(true_labels, pred_labels)
            if "adjusted_rand_score" in metrics:
                results["adjusted_rand_score"] = ClusteringMetrics.adjusted_rand_score(true_labels, pred_labels)
            if "mutual_info_score" in metrics:
                results["mutual_info_score"] = ClusteringMetrics.mutual_info_score(true_labels, pred_labels)
            if "adjusted_mutual_info_score" in metrics:
                results["adjusted_mutual_info_score"] = ClusteringMetrics.adjusted_mutual_info_score(true_labels, pred_labels)
            if "normalized_mutual_info_score" in metrics:
                results["normalized_mutual_info_score"] = ClusteringMetrics.normalized_mutual_info_score(true_labels, pred_labels)
            if "fowlkes_mallows_score" in metrics:
                results["fowlkes_mallows_score"] = ClusteringMetrics.fowlkes_mallows_score(true_labels, pred_labels)
            if "homogeneity_score" in metrics:
                results["homogeneity_score"] = ClusteringMetrics.homogeneity_score(true_labels, pred_labels)
            if "completeness_score" in metrics:
                results["completeness_score"] = ClusteringMetrics.completeness_score(true_labels, pred_labels)
            if "v_measure_score" in metrics:
                results["v_measure_score"] = ClusteringMetrics.v_measure_score(true_labels, pred_labels)
            if "purity_score" in metrics:
                results["purity_score"] = ClusteringMetrics.purity_score(true_labels, pred_labels)
            if "classification_report" in metrics:
                results["classification_report"] = ClusteringMetrics.classification_report(true_labels, pred_labels)
            if "confusion_matrix" in metrics:
                results["confusion_matrix"] = ClusteringMetrics.confusion_matrix(true_labels, pred_labels)

        # Unsupervised metrics
        unique_pred_labels = torch.unique(pred_labels)
        if len(unique_pred_labels) > 1:
            if "silhouette_score" in metrics:
                results["silhouette_score"] = ClusteringMetrics.silhouette_score(
                    X.cpu(), pred_labels, gmm_model.n_components
                )
            if "davies_bouldin_index" in metrics:
                results["davies_bouldin_index"] = ClusteringMetrics.davies_bouldin_index(
                    X.cpu(), pred_labels, gmm_model.n_components
                )
            if "calinski_harabasz_score" in metrics:
                results["calinski_harabasz_score"] = ClusteringMetrics.calinski_harabasz_score(
                    X.cpu(), pred_labels, gmm_model.n_components
                )
            if "dunn_index" in metrics:
                results["dunn_index"] = ClusteringMetrics.dunn_index(
                    X.cpu(), pred_labels, gmm_model.n_components
                )

        # Info criteria
        if "bic_score" in metrics:
            results["bic_score"] = ClusteringMetrics.bic_score(
                gmm_model.lower_bound_,
                X.cpu(),
                gmm_model.n_components,
                gmm_model.covariance_type
            )
        if "aic_score" in metrics:
            results["aic_score"] = ClusteringMetrics.aic_score(
                gmm_model.lower_bound_,
                X.cpu(),
                gmm_model.n_components,
                gmm_model.covariance_type
            )

        return results

Functions

bic_score(lower_bound_, X, n_components, covariance_type) staticmethod

Compute the Bayesian Information Criterion (BIC) for a GMM given its average log-likelihood (lower bound).

$ \text{BIC} = n_{\text{params}} \ln(n_{\text{samples}}) - 2 \times \text{log_likelihood} $

Parameters:

Name Type Description Default
lower_bound_ float

Average (per-sample) log-likelihood or lower bound from the GMM.

required
X Tensor

Data used in fitting, shape (n_samples, n_features).

required
n_components int

Number of mixture components in the GMM.

required
covariance_type str

Covariance type, one of {'full', 'tied', 'diag', 'spherical'}.

required

Returns:

Type Description
float

The BIC score (lower is better).

Source code in tgmm/metrics.py
@staticmethod
def bic_score(
    lower_bound_: float,
    X: torch.Tensor,
    n_components: int,
    covariance_type: str
) -> float:
    r"""
    Compute the Bayesian Information Criterion (BIC) for a GMM given its
    average log-likelihood (lower bound).

    $ \text{BIC} = n_{\text{params}} \ln(n_{\text{samples}}) - 2 \times \text{log\_likelihood} $

    Parameters
    ----------
    lower_bound_ : float
        Average (per-sample) log-likelihood or lower bound from the GMM.
    X : torch.Tensor
        Data used in fitting, shape (n_samples, n_features).
    n_components : int
        Number of mixture components in the GMM.
    covariance_type : str
        Covariance type, one of {'full', 'tied', 'diag', 'spherical'}.

    Returns
    -------
    float
        The BIC score (lower is better).
    """
    n_samples, n_features = X.shape

    # Determine number of free parameters in covariance
    if covariance_type == 'full':
        cov_params = n_components * n_features * (n_features + 1) / 2.0
    elif covariance_type == 'diag':
        cov_params = n_components * n_features
    elif covariance_type == 'tied':
        cov_params = n_features * (n_features + 1) / 2.0
    elif covariance_type == 'spherical':
        cov_params = n_components
    else:
        raise ValueError(f"Unsupported covariance type: {covariance_type}")

    # Means + weights
    mean_params = n_features * n_components
    weight_params = n_components - 1

    n_parameters = cov_params + mean_params + weight_params
    log_likelihood = lower_bound_ * n_samples  # total log-likelihood

    bic = n_parameters * torch.log(torch.tensor(n_samples, dtype=torch.float)) - 2.0 * log_likelihood
    return bic.item()

aic_score(lower_bound_, X, n_components, covariance_type) staticmethod

Compute the Akaike Information Criterion (AIC) for a GMM.

$ \text{AIC} = 2 \times n_{\text{params}} - 2 \times \text{log_likelihood} $

Parameters:

Name Type Description Default
lower_bound_ float

Average (per-sample) log-likelihood from the GMM.

required
X Tensor

Data used in fitting, shape (n_samples, n_features).

required
n_components int

Number of mixture components.

required
covariance_type str

Covariance type, one of {'full', 'tied', 'diag', 'spherical'}.

required

Returns:

Type Description
float

The AIC score (lower is better).

Source code in tgmm/metrics.py
@staticmethod
def aic_score(
    lower_bound_: float,
    X: torch.Tensor,
    n_components: int,
    covariance_type: str
) -> float:
    r"""
    Compute the Akaike Information Criterion (AIC) for a GMM.

    $ \text{AIC} = 2 \times n_{\text{params}} - 2 \times \text{log\_likelihood} $

    Parameters
    ----------
    lower_bound_ : float
        Average (per-sample) log-likelihood from the GMM.
    X : torch.Tensor
        Data used in fitting, shape (n_samples, n_features).
    n_components : int
        Number of mixture components.
    covariance_type : str
        Covariance type, one of {'full', 'tied', 'diag', 'spherical'}.

    Returns
    -------
    float
        The AIC score (lower is better).
    """
    n_samples, n_features = X.shape
    if covariance_type == 'full':
        cov_params = n_components * n_features * (n_features + 1) / 2.0
    elif covariance_type == 'diag':
        cov_params = n_components * n_features
    elif covariance_type == 'tied':
        cov_params = n_features * (n_features + 1) / 2.0
    elif covariance_type == 'spherical':
        cov_params = n_components
    else:
        raise ValueError(f"Unsupported covariance type: {covariance_type}")

    mean_params = n_features * n_components
    weight_params = n_components - 1

    n_parameters = cov_params + mean_params + weight_params
    log_likelihood = lower_bound_ * n_samples
    aic = 2.0 * n_parameters - 2.0 * log_likelihood

    return aic

silhouette_score(X, labels, n_components) staticmethod

Compute the silhouette score for a partition of the data.

\[ \mathrm{silhouette}(i) = \frac{b_i - a_i}{\max(a_i, b_i)} \]

Where: - \(a_i\) is the mean distance to points in the same cluster. - \(b_i\) is the minimum mean distance to points in a different cluster.

Parameters:

Name Type Description Default
X Tensor

Data, shape (n_samples, n_features).

required
labels Tensor

Cluster labels, shape (n_samples,).

required
n_components int

Number of clusters (must be >= 2).

required

Returns:

Type Description
float

The mean silhouette score over all samples (range is typically [-1, 1], though it’s seldom negative in practice if distances are Euclidean).

Source code in tgmm/metrics.py
@staticmethod
def silhouette_score(
    X: torch.Tensor,
    labels: torch.Tensor,
    n_components: int
) -> float:
    r"""
    Compute the silhouette score for a partition of the data.

    $$ \mathrm{silhouette}(i) = \frac{b_i - a_i}{\max(a_i, b_i)} $$

    Where:
    - $a_i$ is the mean distance to points in the same cluster.
    - $b_i$ is the minimum mean distance to points in a different cluster.

    Parameters
    ----------
    X : torch.Tensor
        Data, shape (n_samples, n_features).
    labels : torch.Tensor
        Cluster labels, shape (n_samples,).
    n_components : int
        Number of clusters (must be >= 2).

    Returns
    -------
    float
        The mean silhouette score over all samples (range is typically [-1, 1],
        though it’s seldom negative in practice if distances are Euclidean).
    """
    assert n_components > 1, "Silhouette score is only defined when there are at least 2 clusters."

    labels = labels.to(X.device)
    distances = torch.cdist(X, X)  # (n_samples, n_samples)

    A = torch.zeros(labels.size(0), dtype=torch.float, device=X.device)
    B = torch.full((labels.size(0),), float('inf'), dtype=torch.float, device=X.device)

    # For each cluster i, compute intra-cluster distances and inter-cluster distances
    for i in range(n_components):
        mask_i = (labels == i)
        if mask_i.sum() <= 1:
            continue

        intra_cluster_distances = distances[mask_i][:, mask_i]
        # a_i: average distance within the same cluster
        A[mask_i] = intra_cluster_distances.sum(dim=1) / (mask_i.sum() - 1)

        # b_i: minimum distance to any other cluster
        for j in range(n_components):
            if i == j:
                continue
            mask_j = (labels == j)
            if mask_j.sum() == 0:
                continue
            inter_cluster_distances = distances[mask_i][:, mask_j]
            B[mask_i] = torch.min(B[mask_i], inter_cluster_distances.mean(dim=1))

    silhouette_scores = (B - A) / torch.max(A, B)
    return silhouette_scores.mean().item()

davies_bouldin_index(X, labels, n_components) staticmethod

Compute the Davies-Bouldin index (lower is better).

\[ \text{DB} = \frac{1}{k} \sum_i \max_{j \neq i} \frac{S_i + S_j}{M_{ij}} \]

Where: - \(S_i\) is the average distance of points in cluster i to its centroid. - \(M_{ij}\) is the distance between cluster centroids i and j.

Parameters:

Name Type Description Default
X Tensor

Data, shape (n_samples, n_features).

required
labels Tensor

Cluster labels, shape (n_samples,).

required
n_components int

Number of clusters (must be >= 2).

required

Returns:

Type Description
float

Davies-Bouldin index.

Source code in tgmm/metrics.py
@staticmethod
def davies_bouldin_index(X: torch.Tensor, labels: torch.Tensor, n_components: int) -> float:
    r"""
    Compute the Davies-Bouldin index (lower is better).

    $$ \text{DB} = \frac{1}{k} \sum_i \max_{j \neq i} \frac{S_i + S_j}{M_{ij}} $$

    Where:
    - $S_i$ is the average distance of points in cluster i to its centroid.
    - $M_{ij}$ is the distance between cluster centroids i and j.

    Parameters
    ----------
    X : torch.Tensor
        Data, shape (n_samples, n_features).
    labels : torch.Tensor
        Cluster labels, shape (n_samples,).
    n_components : int
        Number of clusters (must be >= 2).

    Returns
    -------
    float
        Davies-Bouldin index.
    """
    assert n_components > 1, "Davies-Bouldin index is only defined when >= 2 clusters."
    labels = labels.to(X.device)

    # Compute cluster centroids
    centroids = [X[labels == i].mean(dim=0) for i in range(n_components)]
    centroids = torch.stack(centroids, dim=0)

    # Distance matrix among centroids
    cluster_distances = torch.cdist(centroids, centroids)

    similarities = torch.zeros((n_components, n_components), device=X.device)

    for i in range(n_components):
        mask_i = (labels == i)
        dist_i = torch.norm(X[mask_i] - centroids[i], dim=1).mean()

        for j in range(n_components):
            if i == j:
                continue
            mask_j = (labels == j)
            dist_j = torch.norm(X[mask_j] - centroids[j], dim=1).mean()
            similarities[i, j] = (dist_i + dist_j) / cluster_distances[i, j]

    db_index = torch.max(similarities, dim=1).values.mean()
    return db_index.item()

calinski_harabasz_score(X, labels, n_components) staticmethod

Compute the Calinski-Harabasz index (a ratio of between-cluster dispersion to within-cluster dispersion).

Parameters:

Name Type Description Default
X Tensor

Data, shape (n_samples, n_features).

required
labels Tensor

Cluster labels.

required
n_components int

Number of clusters.

required

Returns:

Type Description
float

Calinski-Harabasz index (higher is better).

Source code in tgmm/metrics.py
@staticmethod
def calinski_harabasz_score(X: torch.Tensor, labels: torch.Tensor, n_components: int) -> float:
    r"""
    Compute the Calinski-Harabasz index (a ratio of between-cluster dispersion
    to within-cluster dispersion).

    Parameters
    ----------
    X : torch.Tensor
        Data, shape (n_samples, n_features).
    labels : torch.Tensor
        Cluster labels.
    n_components : int
        Number of clusters.

    Returns
    -------
    float
        Calinski-Harabasz index (higher is better).
    """
    labels = labels.to(X.device)
    centroid_overall = X.mean(dim=0)

    # Cluster centroids
    centroids = [X[labels == i].mean(dim=0) for i in range(n_components)]
    centroids = torch.stack(centroids)

    # Between-cluster dispersion (SSB) & within-cluster (SSW)
    SSB = sum((labels == i).sum() * torch.norm(centroids[i] - centroid_overall).pow(2)
              for i in range(n_components))
    SSW = sum(torch.norm(X[labels == i] - centroids[i], dim=1).pow(2).sum()
              for i in range(n_components))

    n_samples = X.shape[0]
    CH = (SSB / (n_components - 1)) / (SSW / (n_samples - n_components))
    return CH.item()

dunn_index(X, labels, n_components) staticmethod

Compute the Dunn index:

\[ D = \frac{\min_{i \neq j} d(C_i, C_j)}{\max_k \mathrm{diam}(C_k)} \]

Where: - \(d(C_i, C_j)\) is the minimum distance between any points in clusters i, j. - \(\mathrm{diam}(C_k)\) is the maximum distance between any points in cluster k.

Higher Dunn index indicates better cluster separation.

Parameters:

Name Type Description Default
X Tensor

Data, shape (n_samples, n_features).

required
labels Tensor

Cluster labels.

required
n_components int

Number of clusters.

required

Returns:

Type Description
float

Dunn index (higher is better).

Source code in tgmm/metrics.py
@staticmethod
def dunn_index(X: torch.Tensor, labels: torch.Tensor, n_components: int) -> float:
    r"""
    Compute the Dunn index:

    $$ D = \frac{\min_{i \neq j} d(C_i, C_j)}{\max_k \mathrm{diam}(C_k)} $$

    Where:
    - $d(C_i, C_j)$ is the minimum distance between any points in clusters i, j.
    - $\mathrm{diam}(C_k)$ is the maximum distance between any points in cluster k.

    Higher Dunn index indicates better cluster separation.

    Parameters
    ----------
    X : torch.Tensor
        Data, shape (n_samples, n_features).
    labels : torch.Tensor
        Cluster labels.
    n_components : int
        Number of clusters.

    Returns
    -------
    float
        Dunn index (higher is better).
    """
    labels = labels.to(X.device)
    distances = torch.cdist(X, X)

    min_intercluster_dist = float('inf')
    max_intracluster_dist = 0.0

    for i in range(n_components):
        mask_i = (labels == i)
        if mask_i.sum() <= 1:
            continue

        intra_distances = distances[mask_i][:, mask_i]
        max_intracluster_dist = max(max_intracluster_dist, intra_distances.max().item())

        for j in range(i + 1, n_components):
            mask_j = (labels == j)
            if mask_j.sum() == 0:
                continue
            inter_distances = distances[mask_i][:, mask_j]
            current_min = inter_distances.min().item()
            if current_min < min_intercluster_dist:
                min_intercluster_dist = current_min

    dunn_index = (min_intercluster_dist / max_intracluster_dist) if max_intracluster_dist > 0 else 0.0
    return dunn_index

rand_score(labels_true, labels_pred) staticmethod

Rand Index (RI) measures the similarity between two clusterings. It counts the agreement of pairwise assignments.

$ \text{RI} = \frac{\text{TP} + \text{TN}}{\text{TP} + \text{TN} + \text{FP} + \text{FN}} $

Parameters:

Name Type Description Default
labels_true Tensor

Ground-truth labels, shape (n_samples,).

required
labels_pred Tensor

Predicted labels, shape (n_samples,).

required

Returns:

Type Description
float

Rand Index in [0, 1].

Source code in tgmm/metrics.py
@staticmethod
def rand_score(labels_true: torch.Tensor, labels_pred: torch.Tensor) -> float:
    r"""
    Rand Index (RI) measures the similarity between two clusterings.
    It counts the agreement of pairwise assignments.

    $ \text{RI} = \frac{\text{TP} + \text{TN}}{\text{TP} + \text{TN} + \text{FP} + \text{FN}} $

    Parameters
    ----------
    labels_true : torch.Tensor
        Ground-truth labels, shape (n_samples,).
    labels_pred : torch.Tensor
        Predicted labels, shape (n_samples,).

    Returns
    -------
    float
        Rand Index in [0, 1].
    """
    device = labels_true.device
    n_samples = labels_true.size(0)

    # Build contingency matrix
    contingency = torch.zeros(
        (labels_true.max() + 1, labels_pred.max() + 1),
        dtype=torch.float, device=device
    )
    for i in range(n_samples):
        contingency[labels_true[i], labels_pred[i]] += 1

    sum_comb_c = torch.sum(contingency.pow(2) - contingency) / 2
    sum_comb = torch.sum(contingency.sum(dim=1).pow(2) - contingency.sum(dim=1)) / 2
    sum_comb_pred = torch.sum(contingency.sum(dim=0).pow(2) - contingency.sum(dim=0)) / 2

    tp = sum_comb_c
    fp = sum_comb_pred - tp
    fn = sum_comb - tp
    tn = n_samples * (n_samples - 1) / 2 - tp - fp - fn

    ri = (tp + tn) / (tp + fp + fn + tn)
    return ri.item()

adjusted_rand_score(labels_true, labels_pred) staticmethod

Adjusted Rand Index (ARI), which adjusts RI for chance.

$ \mathrm{ARI} = \frac{ \mathrm{RI} - \mathrm{E}[\mathrm{RI}] }{ \max(\mathrm{RI}) - \mathrm{E}[\mathrm{RI}] } $

Parameters:

Name Type Description Default
labels_true Tensor

Ground-truth labels.

required
labels_pred Tensor

Predicted labels.

required

Returns:

Type Description
float

ARI in [-1, 1], though it’s typically in [0, 1].

Source code in tgmm/metrics.py
@staticmethod
def adjusted_rand_score(labels_true: torch.Tensor, labels_pred: torch.Tensor) -> float:
    r"""
    Adjusted Rand Index (ARI), which adjusts RI for chance.

    $ \mathrm{ARI} = \frac{ \mathrm{RI} - \mathrm{E}[\mathrm{RI}] }{ \max(\mathrm{RI}) - \mathrm{E}[\mathrm{RI}] } $

    Parameters
    ----------
    labels_true : torch.Tensor
        Ground-truth labels.
    labels_pred : torch.Tensor
        Predicted labels.

    Returns
    -------
    float
        ARI in [-1, 1], though it’s typically in [0, 1].
    """
    device = labels_true.device
    n_samples = labels_true.size(0)

    # Build contingency matrix
    contingency = torch.zeros(
        (labels_true.max() + 1, labels_pred.max() + 1),
        dtype=torch.float, device=device
    )
    for i in range(n_samples):
        contingency[labels_true[i], labels_pred[i]] += 1

    sum_comb_c = torch.sum(contingency.pow(2) - contingency) / 2
    sum_comb = torch.sum(contingency.sum(dim=1).pow(2) - contingency.sum(dim=1)) / 2
    sum_comb_pred = torch.sum(contingency.sum(dim=0).pow(2) - contingency.sum(dim=0)) / 2

    expected_index = sum_comb * sum_comb_pred / (n_samples * (n_samples - 1) / 2)
    max_index = (sum_comb + sum_comb_pred) / 2
    rand_index = sum_comb_c

    ari = (rand_index - expected_index) / (max_index - expected_index)
    return ari.item()

mutual_info_score(labels_true, labels_pred) staticmethod

Mutual Information (MI) between two clusterings.

$ \mathrm{MI}(U, V) = \sum_{u \in U}\sum_{v \in V} p(u, v) \log\frac{p(u,v)}{p(u)p(v)} $

Parameters:

Name Type Description Default
labels_true Tensor

Ground-truth labels.

required
labels_pred Tensor

Predicted labels.

required

Returns:

Type Description
float

Mutual information (>= 0).

Source code in tgmm/metrics.py
@staticmethod
def mutual_info_score(labels_true: torch.Tensor, labels_pred: torch.Tensor) -> float:
    r"""
    Mutual Information (MI) between two clusterings.

    $ \mathrm{MI}(U, V) = \sum_{u \in U}\sum_{v \in V} p(u, v) \log\frac{p(u,v)}{p(u)p(v)} $

    Parameters
    ----------
    labels_true : torch.Tensor
        Ground-truth labels.
    labels_pred : torch.Tensor
        Predicted labels.

    Returns
    -------
    float
        Mutual information (>= 0).
    """
    device = labels_true.device
    contingency = torch.zeros(
        (labels_true.max() + 1, labels_pred.max() + 1),
        dtype=torch.float, device=device
    )
    for i in range(labels_true.size(0)):
        contingency[labels_true[i], labels_pred[i]] += 1

    contingency /= contingency.sum()
    outer = contingency.sum(dim=1, keepdim=True) * contingency.sum(dim=0, keepdim=True)
    nonzero = contingency > 0
    mi = (contingency[nonzero] *
          (torch.log(contingency[nonzero]) - torch.log(outer[nonzero]))).sum()
    return mi.item()

adjusted_mutual_info_score(labels_true, labels_pred) staticmethod

Adjusted Mutual Information (AMI).

Parameters:

Name Type Description Default
labels_true Tensor

Ground-truth labels.

required
labels_pred Tensor

Predicted labels.

required

Returns:

Type Description
float

AMI in [0, 1].

Source code in tgmm/metrics.py
@staticmethod
def adjusted_mutual_info_score(labels_true: torch.Tensor, labels_pred: torch.Tensor) -> float:
    r"""
    Adjusted Mutual Information (AMI).

    Parameters
    ----------
    labels_true : torch.Tensor
        Ground-truth labels.
    labels_pred : torch.Tensor
        Predicted labels.

    Returns
    -------
    float
        AMI in [0, 1].
    """
    mi = ClusteringMetrics.mutual_info_score(labels_true, labels_pred)
    n_samples = labels_true.size(0)
    true_counts = torch.bincount(labels_true)
    pred_counts = torch.bincount(labels_pred)

    h_true = -torch.sum((true_counts / n_samples) *
                        torch.log(true_counts / n_samples + 1e-10))
    h_pred = -torch.sum((pred_counts / n_samples) *
                        torch.log(pred_counts / n_samples + 1e-10))

    # This "expected_mi" here is a rough approximation
    expected_mi = (h_true * h_pred) / n_samples
    ami = (mi - expected_mi) / (0.5 * (h_true + h_pred) - expected_mi)
    return ami.item()

normalized_mutual_info_score(labels_true, labels_pred) staticmethod

Normalized Mutual Information (NMI) in [0, 1].

$ \text{NMI} = \frac{2 \times \mathrm{MI}}{H(\text{true}) + H(\text{pred})} $

Parameters:

Name Type Description Default
labels_true Tensor

Ground-truth labels.

required
labels_pred Tensor

Predicted labels.

required

Returns:

Type Description
float

NMI.

Source code in tgmm/metrics.py
@staticmethod
def normalized_mutual_info_score(labels_true: torch.Tensor, labels_pred: torch.Tensor) -> float:
    r"""
    Normalized Mutual Information (NMI) in [0, 1].

    $ \text{NMI} = \frac{2 \times \mathrm{MI}}{H(\text{true}) + H(\text{pred})} $

    Parameters
    ----------
    labels_true : torch.Tensor
        Ground-truth labels.
    labels_pred : torch.Tensor
        Predicted labels.

    Returns
    -------
    float
        NMI.
    """
    mi = ClusteringMetrics.mutual_info_score(labels_true, labels_pred)
    n_samples = labels_true.size(0)
    true_counts = torch.bincount(labels_true).float()
    pred_counts = torch.bincount(labels_pred).float()

    h_true = -torch.sum((true_counts / n_samples) *
                        torch.log(true_counts / n_samples + 1e-10))
    h_pred = -torch.sum((pred_counts / n_samples) *
                        torch.log(pred_counts / n_samples + 1e-10))

    nmi = 2.0 * mi / (h_true + h_pred)
    return nmi.item()

fowlkes_mallows_score(labels_true, labels_pred) staticmethod

Fowlkes-Mallows index (FM) = $ \sqrt{ \text{precision} \times \text{recall} } $.

Parameters:

Name Type Description Default
labels_true Tensor

Ground-truth labels.

required
labels_pred Tensor

Predicted labels.

required

Returns:

Type Description
float

FM in [0, 1].

Source code in tgmm/metrics.py
@staticmethod
def fowlkes_mallows_score(labels_true: torch.Tensor, labels_pred: torch.Tensor) -> float:
    r"""
    Fowlkes-Mallows index (FM) = $ \sqrt{ \text{precision} \times \text{recall} } $.

    Parameters
    ----------
    labels_true : torch.Tensor
        Ground-truth labels.
    labels_pred : torch.Tensor
        Predicted labels.

    Returns
    -------
    float
        FM in [0, 1].
    """
    device = labels_true.device
    n_samples = labels_true.size(0)
    contingency = torch.zeros(
        (labels_true.max() + 1, labels_pred.max() + 1),
        dtype=torch.float, device=device
    )
    for i in range(n_samples):
        contingency[labels_true[i], labels_pred[i]] += 1

    tp = torch.sum(contingency.pow(2)) - n_samples
    tp_fp = torch.sum(contingency.sum(dim=0).pow(2)) - n_samples
    tp_fn = torch.sum(contingency.sum(dim=1).pow(2)) - n_samples

    fm = torch.sqrt(tp / tp_fp * tp / tp_fn)
    return fm.item()

completeness_score(labels_true, labels_pred) staticmethod

Completeness score measures how much each cluster contains only samples of a single class.

Parameters:

Name Type Description Default
labels_true Tensor

Ground-truth labels.

required
labels_pred Tensor

Predicted labels.

required

Returns:

Type Description
float

Completeness in [0, 1].

Source code in tgmm/metrics.py
@staticmethod
def completeness_score(labels_true: torch.Tensor, labels_pred: torch.Tensor) -> float:
    r"""
    Completeness score measures how much each cluster contains only samples
    of a single class.

    Parameters
    ----------
    labels_true : torch.Tensor
        Ground-truth labels.
    labels_pred : torch.Tensor
        Predicted labels.

    Returns
    -------
    float
        Completeness in [0, 1].
    """
    device = labels_true.device
    n_samples = labels_true.size(0)
    contingency = torch.zeros(
        (labels_true.max() + 1, labels_pred.max() + 1),
        dtype=torch.float, device=device
    )
    for i in range(n_samples):
        contingency[labels_true[i], labels_pred[i]] += 1

    entropy_true = -torch.sum(labels_true.bincount().float() / n_samples *
                              torch.log(labels_true.bincount().float() / n_samples + 1e-10))

    # H(C | K) (conditional entropy)
    entropy_cond = -torch.sum(contingency / n_samples *
                              torch.log((contingency + 1e-10) /
                                        contingency.sum(dim=1, keepdim=True)))

    comp_score = 1 - entropy_cond / entropy_true
    return comp_score.item()

homogeneity_score(labels_true, labels_pred) staticmethod

Homogeneity score measures if each cluster contains samples of only one class.

Parameters:

Name Type Description Default
labels_true Tensor
required
labels_pred Tensor
required

Returns:

Type Description
float

Homogeneity in [0, 1].

Source code in tgmm/metrics.py
@staticmethod
def homogeneity_score(labels_true: torch.Tensor, labels_pred: torch.Tensor) -> float:
    r"""
    Homogeneity score measures if each cluster contains samples of only one class.

    Parameters
    ----------
    labels_true : torch.Tensor
    labels_pred : torch.Tensor

    Returns
    -------
    float
        Homogeneity in [0, 1].
    """
    device = labels_true.device
    n_samples = labels_true.size(0)
    contingency = torch.zeros(
        (labels_true.max() + 1, labels_pred.max() + 1),
        dtype=torch.float, device=device
    )
    for i in range(n_samples):
        contingency[labels_true[i], labels_pred[i]] += 1

    entropy_pred = -torch.sum(labels_pred.bincount().float() / n_samples *
                              torch.log(labels_pred.bincount().float() / n_samples + 1e-10))

    # H(K | C)
    entropy_cond = -torch.sum(contingency / n_samples *
                              torch.log((contingency + 1e-10) /
                                        contingency.sum(dim=0, keepdim=True)))

    hom_score = 1 - entropy_cond / entropy_pred
    return hom_score.item()

v_measure_score(labels_true, labels_pred) staticmethod

V-measure = $ \frac{2 \times (\text{homogeneity} \times \text{completeness})} {\text{homogeneity} + \text{completeness}} $.

Parameters:

Name Type Description Default
labels_true Tensor
required
labels_pred Tensor
required

Returns:

Type Description
float

V-measure in [0, 1].

Source code in tgmm/metrics.py
@staticmethod
def v_measure_score(labels_true: torch.Tensor, labels_pred: torch.Tensor) -> float:
    r"""
    V-measure = $ \frac{2 \times (\text{homogeneity} \times \text{completeness})}
                       {\text{homogeneity} + \text{completeness}} $.

    Parameters
    ----------
    labels_true : torch.Tensor
    labels_pred : torch.Tensor

    Returns
    -------
    float
        V-measure in [0, 1].
    """
    homogeneity = ClusteringMetrics.homogeneity_score(labels_true, labels_pred)
    completeness = ClusteringMetrics.completeness_score(labels_true, labels_pred)
    if (homogeneity + completeness) == 0:
        return 0.0
    return 2.0 * (homogeneity * completeness) / (homogeneity + completeness)

purity_score(labels_true, labels_pred) staticmethod

Purity score measures how many samples belong to the correct cluster.

$ \mathrm{purity} = \frac{1}{N} \sum_k \max_j \lvert C_k \cap U_j \rvert $

Parameters:

Name Type Description Default
labels_true Tensor
required
labels_pred Tensor
required

Returns:

Type Description
float

Purity in [0, 1].

Source code in tgmm/metrics.py
@staticmethod
def purity_score(labels_true: torch.Tensor, labels_pred: torch.Tensor) -> float:
    r"""
    Purity score measures how many samples belong to the correct cluster.

    $ \mathrm{purity} = \frac{1}{N} \sum_k \max_j \lvert C_k \cap U_j \rvert $

    Parameters
    ----------
    labels_true : torch.Tensor
    labels_pred : torch.Tensor

    Returns
    -------
    float
        Purity in [0, 1].
    """
    device = labels_true.device
    n_samples = labels_true.size(0)
    contingency = torch.zeros(
        (labels_true.max() + 1, labels_pred.max() + 1),
        dtype=torch.float, device=device
    )
    for i in range(n_samples):
        contingency[labels_true[i], labels_pred[i]] += 1

    purity = torch.sum(torch.max(contingency, dim=0).values) / n_samples
    return purity.item()

confusion_matrix(labels_true, labels_pred) staticmethod

Compute a confusion matrix.

Parameters:

Name Type Description Default
labels_true Tensor
required
labels_pred Tensor
required

Returns:

Type Description
Tensor

A 2D (C x C) matrix, where C is the number of unique labels.

Source code in tgmm/metrics.py
@staticmethod
def confusion_matrix(labels_true: torch.Tensor, labels_pred: torch.Tensor) -> torch.Tensor:
    r"""
    Compute a confusion matrix.

    Parameters
    ----------
    labels_true : torch.Tensor
    labels_pred : torch.Tensor

    Returns
    -------
    torch.Tensor
        A 2D (C x C) matrix, where C is the number of unique labels.
    """
    unique_labels = torch.unique(labels_true)
    num_labels = unique_labels.size(0)
    cm = torch.zeros((num_labels, num_labels), dtype=torch.int32)

    for i, label_true in enumerate(unique_labels):
        for j, label_pred in enumerate(unique_labels):
            cm[i, j] = ((labels_true == label_true) & (labels_pred == label_pred)).sum().item()

    return cm

classification_report(labels_true, labels_pred) staticmethod

Compute a simple classification report for each class: precision, recall, F1, Jaccard index, ROC-AUC, and support.

Parameters:

Name Type Description Default
labels_true Tensor
required
labels_pred Tensor
required

Returns:

Type Description
dict

A dictionary keyed by class label, each containing: "precision", "recall", "f1-score", "support", "jaccard", "roc_auc".

Source code in tgmm/metrics.py
@staticmethod
def classification_report(labels_true: torch.Tensor, labels_pred: torch.Tensor) -> dict:
    r"""
    Compute a simple classification report for each class: precision, recall, F1,
    Jaccard index, ROC-AUC, and support.

    Parameters
    ----------
    labels_true : torch.Tensor
    labels_pred : torch.Tensor

    Returns
    -------
    dict
        A dictionary keyed by class label, each containing:
        "precision", "recall", "f1-score", "support", "jaccard", "roc_auc".
    """
    device = labels_true.device
    unique_labels = torch.unique(labels_true)
    report = {}

    for label in unique_labels:
        true_positives = ((labels_true == label) & (labels_pred == label)).sum().item()
        false_positives = ((labels_true != label) & (labels_pred == label)).sum().item()
        false_negatives = ((labels_true == label) & (labels_pred != label)).sum().item()

        precision = (true_positives / (true_positives + false_positives)
                     if (true_positives + false_positives) > 0 else 0.0)
        recall = (true_positives / (true_positives + false_negatives)
                  if (true_positives + false_negatives) > 0 else 0.0)
        f1_score = (2.0 * (precision * recall) / (precision + recall)
                    if (precision + recall) > 0 else 0.0)
        support = (labels_true == label).sum().item()
        jaccard_index = (true_positives / (true_positives + false_positives + false_negatives)
                         if (true_positives + false_positives + false_negatives) > 0 else 0.0)

        binary_true = (labels_true == label).float().to(device)
        binary_pred = (labels_pred == label).float().to(device)
        roc_auc = ClusteringMetrics.roc_auc_score(binary_true, binary_pred)

        report[int(label)] = {
            "precision": precision,
            "recall": recall,
            "f1-score": f1_score,
            "support": support,
            "jaccard": jaccard_index,
            "roc_auc": roc_auc
        }

    return report

roc_auc_score(labels_true, labels_pred) staticmethod

Compute a naive ROC-AUC for binary predictions (0 or 1). If all labels are the same class, returns 1.0 by definition.

Parameters:

Name Type Description Default
labels_true Tensor

Binary ground-truth (0 or 1), shape (n_samples,).

required
labels_pred Tensor

Binary predictions or real-valued probabilities, shape (n_samples,).

required

Returns:

Type Description
float

Area Under the ROC Curve (AUC).

Source code in tgmm/metrics.py
@staticmethod
def roc_auc_score(labels_true: torch.Tensor, labels_pred: torch.Tensor) -> float:
    r"""
    Compute a naive ROC-AUC for binary predictions (0 or 1).
    If all labels are the same class, returns 1.0 by definition.

    Parameters
    ----------
    labels_true : torch.Tensor
        Binary ground-truth (0 or 1), shape (n_samples,).
    labels_pred : torch.Tensor
        Binary predictions or real-valued probabilities, shape (n_samples,).

    Returns
    -------
    float
        Area Under the ROC Curve (AUC).
    """
    if labels_true.sum() == 0 or labels_true.sum() == labels_true.size(0):
        # Degenerate case: all positives or all negatives => AUC = 1.0 or undefined
        return 1.0

    sorted_indices = torch.argsort(labels_pred, descending=True)
    labels_true = labels_true[sorted_indices]

    tpr = torch.cumsum(labels_true, dim=0) / labels_true.sum()
    fpr = torch.cumsum(1 - labels_true, dim=0) / (labels_true.size(0) - labels_true.sum())

    auc = torch.trapz(tpr, fpr)
    return auc.item()

kl_divergence_gmm(gmm_p, gmm_q, n_samples=10000) staticmethod

Approximate the KL divergence \(D_{KL}(p \Vert q)\) between two Gaussian Mixture Models using Monte Carlo sampling from gmm_p.

Parameters:

Name Type Description Default
gmm_p GaussianMixture

The first GMM (interpreted as distribution p).

required
gmm_q GaussianMixture

The second GMM (interpreted as distribution q).

required
n_samples int

Number of samples to draw from gmm_p for the Monte Carlo approximation (default: 10000).

10000

Returns:

Type Description
float

Approximated KL divergence $ \mathrm{E}_{x \sim p}[\log p(x) - \log q(x)]$.

Source code in tgmm/metrics.py
@staticmethod
def kl_divergence_gmm(gmm_p, gmm_q, n_samples: int = 10000) -> float:
    r"""
    Approximate the KL divergence $D_{KL}(p \Vert q)$ between two
    Gaussian Mixture Models using Monte Carlo sampling from ``gmm_p``.

    Parameters
    ----------
    gmm_p : GaussianMixture
        The first GMM (interpreted as distribution p).
    gmm_q : GaussianMixture
        The second GMM (interpreted as distribution q).
    n_samples : int, optional
        Number of samples to draw from gmm_p for the Monte Carlo approximation
        (default: 10000).

    Returns
    -------
    float
        Approximated KL divergence $ \mathrm{E}_{x \sim p}[\log p(x) - \log q(x)]$.
    """
    device = gmm_p.device
    samples, _ = gmm_p.sample(n_samples)  # (n_samples, n_features)
    samples = samples.to(device)

    # Log-likelihood of samples under both GMMs
    log_p = gmm_p.score_samples(samples)
    log_q = gmm_q.score_samples(samples)

    kl_divergence = (log_p - log_q).mean().item()
    return kl_divergence

evaluate_clustering(gmm_model, X, true_labels=None, metrics=None) staticmethod

Evaluate a fitted GMM using a set of requested metrics, possibly with ground-truth labels.

Parameters:

Name Type Description Default
gmm_model GaussianMixture

A fitted GMM model (must have fitted_ == True).

required
X Tensor

Data to evaluate, shape (n_samples, n_features).

required
true_labels Tensor or None

Ground-truth labels for supervised metrics (optional).

None
metrics list of str or None

Which metrics to compute. If None, uses a default set.

None

Returns:

Name Type Description
results dict

A dictionary of metric_name -> metric_value pairs.

Source code in tgmm/metrics.py
@staticmethod
def evaluate_clustering(
    gmm_model,
    X: torch.Tensor,
    true_labels: Optional[torch.Tensor] = None,
    metrics: Optional[list] = None
) -> dict:
    r"""
    Evaluate a fitted GMM using a set of requested metrics, possibly with ground-truth labels.

    Parameters
    ----------
    gmm_model : GaussianMixture
        A fitted GMM model (must have ``fitted_ == True``).
    X : torch.Tensor
        Data to evaluate, shape (n_samples, n_features).
    true_labels : torch.Tensor or None
        Ground-truth labels for supervised metrics (optional).
    metrics : list of str or None
        Which metrics to compute. If None, uses a default set.

    Returns
    -------
    results : dict
        A dictionary of metric_name -> metric_value pairs.
    """
    if not gmm_model.fitted_:
        raise ValueError("The GMM model must be fitted before evaluation.")

    if metrics is None:
        metrics = [
            # Supervised
            "rand_score", "adjusted_rand_score", "mutual_info_score",
            "normalized_mutual_info_score", "adjusted_mutual_info_score",
            "fowlkes_mallows_score", "homogeneity_score", "completeness_score",
            "v_measure_score", "purity_score",
            # Classification-based
            "classification_report", "confusion_matrix",
            # Unsupervised
            "silhouette_score", "davies_bouldin_index", "calinski_harabasz_score",
            "dunn_index", "bic_score", "aic_score",
        ]

    # Predict cluster labels
    pred_labels = gmm_model.predict(X).cpu()
    results = {}

    # If ground-truth labels provided, compute supervised metrics
    if true_labels is not None:
        true_labels = true_labels.cpu()

        if "rand_score" in metrics:
            results["rand_score"] = ClusteringMetrics.rand_score(true_labels, pred_labels)
        if "adjusted_rand_score" in metrics:
            results["adjusted_rand_score"] = ClusteringMetrics.adjusted_rand_score(true_labels, pred_labels)
        if "mutual_info_score" in metrics:
            results["mutual_info_score"] = ClusteringMetrics.mutual_info_score(true_labels, pred_labels)
        if "adjusted_mutual_info_score" in metrics:
            results["adjusted_mutual_info_score"] = ClusteringMetrics.adjusted_mutual_info_score(true_labels, pred_labels)
        if "normalized_mutual_info_score" in metrics:
            results["normalized_mutual_info_score"] = ClusteringMetrics.normalized_mutual_info_score(true_labels, pred_labels)
        if "fowlkes_mallows_score" in metrics:
            results["fowlkes_mallows_score"] = ClusteringMetrics.fowlkes_mallows_score(true_labels, pred_labels)
        if "homogeneity_score" in metrics:
            results["homogeneity_score"] = ClusteringMetrics.homogeneity_score(true_labels, pred_labels)
        if "completeness_score" in metrics:
            results["completeness_score"] = ClusteringMetrics.completeness_score(true_labels, pred_labels)
        if "v_measure_score" in metrics:
            results["v_measure_score"] = ClusteringMetrics.v_measure_score(true_labels, pred_labels)
        if "purity_score" in metrics:
            results["purity_score"] = ClusteringMetrics.purity_score(true_labels, pred_labels)
        if "classification_report" in metrics:
            results["classification_report"] = ClusteringMetrics.classification_report(true_labels, pred_labels)
        if "confusion_matrix" in metrics:
            results["confusion_matrix"] = ClusteringMetrics.confusion_matrix(true_labels, pred_labels)

    # Unsupervised metrics
    unique_pred_labels = torch.unique(pred_labels)
    if len(unique_pred_labels) > 1:
        if "silhouette_score" in metrics:
            results["silhouette_score"] = ClusteringMetrics.silhouette_score(
                X.cpu(), pred_labels, gmm_model.n_components
            )
        if "davies_bouldin_index" in metrics:
            results["davies_bouldin_index"] = ClusteringMetrics.davies_bouldin_index(
                X.cpu(), pred_labels, gmm_model.n_components
            )
        if "calinski_harabasz_score" in metrics:
            results["calinski_harabasz_score"] = ClusteringMetrics.calinski_harabasz_score(
                X.cpu(), pred_labels, gmm_model.n_components
            )
        if "dunn_index" in metrics:
            results["dunn_index"] = ClusteringMetrics.dunn_index(
                X.cpu(), pred_labels, gmm_model.n_components
            )

    # Info criteria
    if "bic_score" in metrics:
        results["bic_score"] = ClusteringMetrics.bic_score(
            gmm_model.lower_bound_,
            X.cpu(),
            gmm_model.n_components,
            gmm_model.covariance_type
        )
    if "aic_score" in metrics:
        results["aic_score"] = ClusteringMetrics.aic_score(
            gmm_model.lower_bound_,
            X.cpu(),
            gmm_model.n_components,
            gmm_model.covariance_type
        )

    return results