Clustering Metrics: Comprehensive Evaluation of Gaussian Mixture Models¶
This notebook provides a comprehensive evaluation of clustering performance using both unsupervised and supervised metrics. We demonstrate how to:
- Apply unsupervised metrics to determine the optimal number of components
- Compare supervised metrics with scikit-learn implementations
- Visualize clustering results with confusion matrices and classification reports
- Analyze KL divergence between different GMM models
The notebook covers both theoretical foundations (with mathematical definitions) and practical implementations of clustering evaluation metrics.
import numpy as np
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import os
os.chdir('../')
from tgmm import GaussianMixture, ClusteringMetrics, dynamic_figsize, plot_gmm, match_predicted_to_true_labels
device = 'cuda' if torch.cuda.is_available() else 'cpu'
random_state = 42
np.random.seed(random_state)
torch.manual_seed(random_state)
if device == 'cuda':
    torch.cuda.manual_seed(random_state)
    print('CUDA version:', torch.version.cuda)
    print('Device:', torch.cuda.get_device_name(0))
else:
    print('Using CPU')
CUDA version: 12.4
Device: NVIDIA GeForce RTX 4060 Laptop GPU
Synthetic Data Generation¶
The synthetic dataset is generated by combining four Gaussian components:
- Component 1: centered at `[0, 2]` with spherical covariance.
- Component 2: centered at `[2, -2]` with spherical covariance (fewer points).
- Component 3: centered at `[0, 0]` with diagonal covariance.
- Component 4: centered at `[2, 2]` with full covariance.
n_samples = [800, 200, 1000, 1000]
centers = [np.array([0, 2]),
np.array([2, -2]),
np.array([0, 0]),
np.array([2, 2])]
covs = [
1.0 * np.eye(2), # spherical covariance
0.5 * np.eye(2), # spherical covariance, fewer points
np.array([[2, 0], [0, 0.5]]), # diagonal covariance
np.array([[0.2, 0.5], [0.5, 2]]) # full covariance
]
components = []
for n, center, cov in zip(n_samples, centers, covs):
    # Transform standard-normal draws by the Cholesky factor of cov so the
    # samples actually have covariance cov (multiplying by cov itself would
    # yield covariance cov @ cov.T instead).
    samples = np.random.randn(n, 2) @ np.linalg.cholesky(cov).T + center
    components.append(samples)
X = np.vstack(components)
labels = np.concatenate([i * np.ones(n) for i, n in enumerate(n_samples)])
legend_labels = [f'Component {i+1}' for i in range(len(n_samples))]
X_tensor = torch.tensor(X, dtype=torch.float32, device=device)
y_tensor = torch.tensor(labels, dtype=torch.long, device=device)
n_features = X.shape[1]
n_components = len(n_samples)
plot_gmm(X=X, true_labels=labels, title='Original Data', legend_labels=legend_labels)
plt.show()
Unsupervised Clustering Metrics Definitions¶
Below are the definitions of the unsupervised clustering metrics used in this notebook.
Silhouette Score¶
The Silhouette Score quantifies how similar each data point is to its own cluster compared to other clusters. For a data point $ \mathbf{x}_i $ belonging to cluster $ C_k $, the silhouette value is defined as
$$ s(i) = \frac{b(i) - a(i)}{\max\{a(i), b(i)\}}, $$
where:
- $ a(i) $ is the average distance between $\mathbf{x}_i$ and all other points in the same cluster $ C_k $ (i.e., the intra-cluster distance),
- $ b(i) $ is the smallest average distance between $\mathbf{x}_i$ and all points in any other cluster (i.e., the nearest-cluster distance).
A high silhouette score (close to 1) indicates that the data point is well matched to its own cluster and poorly matched to neighboring clusters.
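As a sanity check, the definition can be evaluated directly on a tiny dataset and compared against scikit-learn. This is a standalone NumPy sketch of the formula above, not the `ClusteringMetrics` implementation used later in the notebook:

```python
import numpy as np
from sklearn.metrics import silhouette_score

# Tiny example: two well-separated 1-D clusters.
X = np.array([[0.0], [0.2], [0.4], [10.0], [10.2], [10.4]])
labels = np.array([0, 0, 0, 1, 1, 1])

def silhouette_naive(X, labels):
    """Mean of s(i) = (b(i) - a(i)) / max(a(i), b(i)) over all points."""
    dists = np.linalg.norm(X[:, None] - X[None, :], axis=-1)
    s = np.zeros(len(X))
    for i in range(len(X)):
        same = labels == labels[i]
        same[i] = False                              # exclude the point itself
        a = dists[i, same].mean()                    # intra-cluster distance a(i)
        b = min(dists[i, labels == k].mean()         # nearest-cluster distance b(i)
                for k in np.unique(labels) if k != labels[i])
        s[i] = (b - a) / max(a, b)
    return s.mean()

print(np.isclose(silhouette_naive(X, labels), silhouette_score(X, labels)))
```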
Davies-Bouldin Index¶
The Davies-Bouldin Index (DB) measures the average similarity between each cluster and its most similar one. It is defined as
$$ \text{DB} = \frac{1}{k} \sum_{i=1}^{k} \max_{j \neq i} \frac{S_i + S_j}{M_{ij}}, $$
where:
- $ S_i $ is the average distance between the points in cluster $ i $ and the centroid of $ i $,
- $ M_{ij} $ is the distance between the centroids of clusters $ i $ and $ j $.
Lower values of the Davies-Bouldin Index indicate better clustering quality, as they reflect smaller within-cluster dispersion relative to the separation between clusters.
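The same formula can be checked term by term against `sklearn.metrics.davies_bouldin_score`; the sketch below is a direct NumPy transcription of the definition, not the notebook's `ClusteringMetrics` code:

```python
import numpy as np
from sklearn.metrics import davies_bouldin_score

def davies_bouldin(X, labels):
    """DB = mean over clusters of max_{j != i} (S_i + S_j) / M_ij."""
    ks = np.unique(labels)
    cents = np.array([X[labels == k].mean(axis=0) for k in ks])
    # S_i: mean distance of cluster members to their centroid
    S = np.array([np.linalg.norm(X[labels == k] - c, axis=1).mean()
                  for k, c in zip(ks, cents)])
    # M_ij: pairwise centroid distances
    M = np.linalg.norm(cents[:, None] - cents[None, :], axis=-1)
    db = 0.0
    for i in range(len(ks)):
        db += max((S[i] + S[j]) / M[i, j] for j in range(len(ks)) if j != i)
    return db / len(ks)

rng = np.random.default_rng(0)
X = np.vstack([rng.normal(0, 1, (100, 2)), rng.normal(6, 1, (100, 2))])
labels = np.repeat([0, 1], 100)
print(np.isclose(davies_bouldin(X, labels), davies_bouldin_score(X, labels)))
```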
Calinski-Harabasz Score¶
The Calinski-Harabasz Score (CH) (also known as the Variance Ratio Criterion) is given by
$$ \text{CH} = \frac{n-k}{k-1}\cdot\frac{\sum _{i=1}^{k}n_{i}||\mathbf {c} _{i}-\mathbf {c} ||^{2}}{\sum _{i=1}^{k}\sum _{\mathbf {x} \in C_{i}}||\mathbf {x} -\mathbf {c} _{i}||^{2}}, $$
where:
- $ k $ is the number of clusters,
- $ n $ is the total number of samples,
- $ n_i $ is the number of points in cluster $ C_i $,
- $ \mathbf{c}_i $ is the centroid of cluster $ C_i $ and $ \mathbf{c} $ is the global centroid of the data.
Higher Calinski-Harabasz scores suggest a model in which clusters are dense and well separated.
Dunn Index¶
The Dunn Index seeks to identify clusters that are both compact and well separated. It is defined as
$$ D = \frac{\min\limits_{i \neq j} d(C_i, C_j)}{\max\limits_{k} \mathrm{diam}(C_k)}, $$
where:
- $ d(C_i, C_j) $ is the minimum distance between any two points in clusters $ C_i $ and $ C_j $,
- $ \mathrm{diam}(C_k) $ is the maximum distance between any two points in cluster $ C_k $.
A higher Dunn Index indicates better clustering, meaning that the clusters are more compact and well separated.
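Scikit-learn does not ship a Dunn Index, so the definition is worth spelling out in code. The sketch below uses single-linkage separation between clusters and the full pairwise diameter within each cluster (one common convention; it is illustrative, not the `ClusteringMetrics` implementation):

```python
import numpy as np

def dunn_index(X, labels):
    """Dunn index: min inter-cluster distance / max cluster diameter."""
    dists = np.linalg.norm(X[:, None] - X[None, :], axis=-1)
    clusters = np.unique(labels)
    # diam(C_k): maximum pairwise distance within a cluster
    diam = max(dists[np.ix_(labels == k, labels == k)].max() for k in clusters)
    # d(C_i, C_j): minimum distance between points of different clusters
    sep = min(dists[np.ix_(labels == i, labels == j)].min()
              for i in clusters for j in clusters if i < j)
    return sep / diam

X = np.array([[0.0, 0.0], [0.0, 1.0], [5.0, 0.0], [5.0, 1.0]])
labels = np.array([0, 0, 1, 1])
print(dunn_index(X, labels))  # separation 5.0 / diameter 1.0 = 5.0
```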
Bayesian Information Criterion (BIC)¶
The Bayesian Information Criterion (BIC) is used for model selection by penalizing model complexity while rewarding goodness-of-fit. It is computed as
$$ \text{BIC} = n_{\text{params}} \cdot \ln(N) - 2 \cdot \mathcal{L}, $$
where:
- $ n_{\text{params}} $ is the number of free parameters in the model,
- $ N $ is the number of samples,
- $ \mathcal{L} $ is the log-likelihood of the model.
Lower BIC values indicate a model that better balances fit and simplicity.
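The parameter count can be verified against scikit-learn's `GaussianMixture.bic`: a full-covariance GMM with $k$ components in $d$ dimensions has $(k-1)$ free mixture weights, $kd$ mean entries, and $kd(d+1)/2$ covariance entries. A small sketch (using synthetic data purely for illustration):

```python
import numpy as np
from sklearn.mixture import GaussianMixture

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 2))

k, d = 3, 2
gmm = GaussianMixture(n_components=k, covariance_type='full', random_state=0).fit(X)

# Free parameters: (k-1) weights + k*d means + k*d*(d+1)/2 covariance entries.
n_params = (k - 1) + k * d + k * d * (d + 1) // 2
log_lik = gmm.score(X) * len(X)   # score() returns mean log-likelihood per sample
bic_manual = n_params * np.log(len(X)) - 2 * log_lik

print(np.isclose(bic_manual, gmm.bic(X)))
```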
Akaike Information Criterion (AIC)¶
The Akaike Information Criterion (AIC) is another metric for model selection defined as
$$ \text{AIC} = 2 \cdot n_{\text{params}} - 2 \cdot \mathcal{L}. $$
As with BIC, lower AIC values suggest a model that achieves a good trade-off between complexity and fit quality.
Note: In practice, the ideal number of clusters is typically determined by seeking a maximum in the Silhouette Score, Calinski-Harabasz Score, and Dunn Index, while simultaneously looking for minima in the Davies-Bouldin Index, AIC, and BIC.
components_range = np.arange(2, 11)
silhouette_vals = torch.zeros(len(components_range), device=device)
davies_vals = torch.zeros(len(components_range), device=device)
calinski_vals = torch.zeros(len(components_range), device=device)
dunn_vals = torch.zeros(len(components_range), device=device)
bic_vals = torch.zeros(len(components_range), device=device)
aic_vals = torch.zeros(len(components_range), device=device)
# Fit a GMM for each n in components_range
for i, n in tqdm(enumerate(components_range), total=len(components_range), desc="Evaluating range"):
    gmm = GaussianMixture(
        n_features=n_features,
        n_components=n,
        covariance_type='full',
        max_iter=1000,
        init_params='kmeans',
        device=device,
        random_state=random_state,
        tol=1e-5,
        reg_covar=1e-7
    )
    # Fit
    gmm.fit(X_tensor)
    labels_pred = gmm.predict(X_tensor)  # shape (N,)
    # Compute unsupervised metrics
    silhouette_vals[i] = ClusteringMetrics.silhouette_score(X_tensor, labels_pred, n_components=n)
    davies_vals[i] = ClusteringMetrics.davies_bouldin_index(X_tensor, labels_pred, n_components=n)
    calinski_vals[i] = ClusteringMetrics.calinski_harabasz_score(X_tensor, labels_pred, n_components=n)
    dunn_vals[i] = ClusteringMetrics.dunn_index(X_tensor, labels_pred, n_components=n)
    bic_vals[i] = ClusteringMetrics.bic_score(gmm.score(X_tensor), X_tensor, n, gmm.covariance_type)
    aic_vals[i] = ClusteringMetrics.aic_score(gmm.score(X_tensor), X_tensor, n, gmm.covariance_type)
# Compute the predicted ideal number of components for each metric
sil_best = components_range[torch.argmax(silhouette_vals)].item()
davies_best = components_range[torch.argmin(davies_vals)].item()
calinski_best = components_range[torch.argmax(calinski_vals)].item()
dunn_best = components_range[torch.argmax(dunn_vals)].item()
bic_best = components_range[torch.argmin(bic_vals)].item()
aic_best = components_range[torch.argmin(aic_vals)].item()
# Create a 3x2 grid of subplots
nrows, ncols = 3, 2
fig, axs = plt.subplots(nrows, ncols, figsize=dynamic_figsize(nrows, ncols))
# 1) Silhouette Score (Row 1, Col 1)
axs[0, 0].plot(components_range, silhouette_vals.cpu(), 'o-b')
axs[0, 0].axvline(x=n_components, color='r', linestyle='--', label='True # of Components')
axs[0, 0].axvline(x=sil_best, color='k', linestyle=':', label='Predicted Ideal # of Components')
axs[0, 0].set_title('Silhouette Score')
axs[0, 0].set_xlabel('Number of Components')
axs[0, 0].set_ylabel('Score')
axs[0, 0].grid(True)
axs[0, 0].legend(loc='upper right')
# 2) Davies-Bouldin Index (Row 1, Col 2)
axs[0, 1].plot(components_range, davies_vals.cpu(), 'o-g')
axs[0, 1].axvline(x=n_components, color='r', linestyle='--', label='True # of Components')
axs[0, 1].axvline(x=davies_best, color='k', linestyle=':', label='Predicted Ideal # of Components')
axs[0, 1].set_title('Davies-Bouldin Index')
axs[0, 1].set_xlabel('Number of Components')
axs[0, 1].grid(True)
axs[0, 1].legend(loc='upper right')
# 3) Calinski-Harabasz Score (Row 2, Col 1)
axs[1, 0].plot(components_range, calinski_vals.cpu(), 'o-y')
axs[1, 0].axvline(x=n_components, color='r', linestyle='--', label='True # of Components')
axs[1, 0].axvline(x=calinski_best, color='k', linestyle=':', label='Predicted Ideal # of Components')
axs[1, 0].set_title('Calinski-Harabasz Score')
axs[1, 0].set_xlabel('Number of Components')
axs[1, 0].grid(True)
axs[1, 0].legend(loc='upper right')
# 4) Dunn Index (Row 2, Col 2)
axs[1, 1].plot(components_range, dunn_vals.cpu(), 'o-', color='orange')
axs[1, 1].axvline(x=n_components, color='r', linestyle='--', label='True # of Components')
axs[1, 1].axvline(x=dunn_best, color='k', linestyle=':', label='Predicted Ideal # of Components')
axs[1, 1].set_title('Dunn Index')
axs[1, 1].set_xlabel('Number of Components')
axs[1, 1].grid(True)
axs[1, 1].legend(loc='upper right')
# 5) BIC Score (Row 3, Col 1)
axs[2, 0].plot(components_range, bic_vals.cpu(), 'o-m')
axs[2, 0].axvline(x=n_components, color='r', linestyle='--', label='True # of Components')
axs[2, 0].axvline(x=bic_best, color='k', linestyle=':', label='Predicted Ideal # of Components')
axs[2, 0].set_title('BIC Score')
axs[2, 0].set_xlabel('Number of Components')
axs[2, 0].grid(True)
axs[2, 0].legend(loc='upper right')
# 6) AIC Score (Row 3, Col 2)
axs[2, 1].plot(components_range, aic_vals.cpu(), 'o-c')
axs[2, 1].axvline(x=n_components, color='r', linestyle='--', label='True # of Components')
axs[2, 1].axvline(x=aic_best, color='k', linestyle=':', label='Predicted Ideal # of Components')
axs[2, 1].set_title('AIC Score')
axs[2, 1].set_xlabel('Number of Components')
axs[2, 1].grid(True)
axs[2, 1].legend(loc='upper right')
plt.suptitle("Unsupervised Clustering Metrics for a GMM")
plt.tight_layout(rect=[0, 0, 1, 0.98])
plt.show()
print("=== Best Number of Components According to Each Metric ===")
print(f"Silhouette Best: {sil_best}")
print(f"Davies-Bouldin Best (lowest): {davies_best}")
print(f"Calinski-Harabasz Best: {calinski_best}")
print(f"Dunn Index Best: {dunn_best}")
print(f"BIC Best (lowest): {bic_best}")
print(f"AIC Best (lowest): {aic_best}")
Evaluating range: 100%|██████████| 9/9 [00:00<00:00, 12.44it/s]
=== Best Number of Components According to Each Metric ===
Silhouette Best: 3
Davies-Bouldin Best (lowest): 4
Calinski-Harabasz Best: 3
Dunn Index Best: 5
BIC Best (lowest): 4
AIC Best (lowest): 4
Supervised Clustering Metrics Definitions¶
When ground truth labels are available, we can evaluate clustering performance using supervised metrics that measure the agreement between predicted and true cluster assignments.
Rand Index (RI)¶
The Rand Index measures the similarity between two clusterings by considering all pairs of samples and counting pairs that are assigned to the same or different clusters in both clusterings.
$$ \text{RI} = \frac{TP + TN}{TP + TN + FP + FN} $$
where:
- $TP$ (True Positives): pairs that are in the same cluster in both clusterings
- $TN$ (True Negatives): pairs that are in different clusters in both clusterings
- $FP$ (False Positives): pairs placed in the same predicted cluster but in different true clusters
- $FN$ (False Negatives): pairs placed in different predicted clusters but in the same true cluster
The Rand Index ranges from 0 to 1, where 1 indicates perfect agreement.
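The pairwise definition can be evaluated literally on a toy example and checked against `sklearn.metrics.rand_score` (a standalone sketch, independent of the notebook's implementation):

```python
import numpy as np
from itertools import combinations
from sklearn.metrics import rand_score

y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([0, 0, 1, 2, 2, 2])

# A pair "agrees" if both clusterings put it together, or both keep it apart.
agree = sum((y_true[i] == y_true[j]) == (y_pred[i] == y_pred[j])
            for i, j in combinations(range(len(y_true)), 2))
n_pairs = len(y_true) * (len(y_true) - 1) // 2
ri = agree / n_pairs
print(ri)  # 12 agreeing pairs out of 15 -> 0.8
print(np.isclose(ri, rand_score(y_true, y_pred)))
```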
Adjusted Rand Index (ARI)¶
The Adjusted Rand Index corrects the Rand Index for chance by subtracting the expected value and normalizing by the maximum possible value:
$$ \text{ARI} = \frac{\text{RI} - \mathbb{E}[\text{RI}]}{\max(\text{RI}) - \mathbb{E}[\text{RI}]} $$
More explicitly, using the contingency table approach:
$$ \text{ARI} = \frac{\sum_{ij} \binom{n_{ij}}{2} - \left[\sum_i \binom{a_i}{2} \sum_j \binom{b_j}{2}\right] / \binom{n}{2}}{\frac{1}{2}\left[\sum_i \binom{a_i}{2} + \sum_j \binom{b_j}{2}\right] - \left[\sum_i \binom{a_i}{2} \sum_j \binom{b_j}{2}\right] / \binom{n}{2}} $$
where $n_{ij}$ is the number of samples assigned to cluster $i$ of the true clustering and cluster $j$ of the predicted clustering, and $a_i = \sum_j n_{ij}$ and $b_j = \sum_i n_{ij}$ are the row and column sums of the contingency table. ARI ranges from -1 to 1, with 1 indicating perfect agreement and 0 indicating random labeling.
Mutual Information (MI)¶
Mutual Information measures the amount of information obtained about one clustering by observing the other:
$$ \text{MI}(U,V) = \sum_{i=1}^{|U|} \sum_{j=1}^{|V|} P(i,j) \log\frac{P(i,j)}{P(i)P(j)} $$
where:
- $U$ and $V$ are the true and predicted clusterings
- $P(i,j) = \frac{|U_i \cap V_j|}{N}$ is the probability that a point belongs to clusters $U_i$ and $V_j$
- $P(i) = \frac{|U_i|}{N}$ and $P(j) = \frac{|V_j|}{N}$ are marginal probabilities
Normalized Mutual Information (NMI)¶
Normalized Mutual Information scales MI to the range [0,1] by normalizing with the entropy of the clusterings:
$$ \text{NMI}(U,V) = \frac{2 \times \text{MI}(U,V)}{H(U) + H(V)} $$
where $H(U) = -\sum_{i=1}^{|U|} P(i) \log P(i)$ is the entropy of clustering $U$.
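Both MI and NMI follow directly from the contingency table. A small sketch checked against scikit-learn, which also uses natural logarithms and (by default) the arithmetic-mean normalization $2\,\text{MI}/(H(U)+H(V))$ shown above:

```python
import numpy as np
from sklearn.metrics import mutual_info_score, normalized_mutual_info_score

y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([0, 0, 1, 1, 1, 2])

N = len(y_true)
# Joint distribution P(i, j) from the contingency table.
cont = np.zeros((3, 3))
for t, p in zip(y_true, y_pred):
    cont[t, p] += 1
P = cont / N
Pi, Pj = P.sum(axis=1), P.sum(axis=0)   # marginals P(i), P(j)

nz = P > 0                              # skip empty cells to avoid log(0)
mi = (P[nz] * np.log(P[nz] / np.outer(Pi, Pj)[nz])).sum()

h_u = -(Pi * np.log(Pi)).sum()
h_v = -(Pj * np.log(Pj)).sum()
nmi = 2 * mi / (h_u + h_v)

print(np.isclose(mi, mutual_info_score(y_true, y_pred)),
      np.isclose(nmi, normalized_mutual_info_score(y_true, y_pred)))
```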
Adjusted Mutual Information (AMI)¶
Adjusted Mutual Information corrects MI for chance, similar to how ARI corrects RI:
$$ \text{AMI}(U,V) = \frac{\text{MI}(U,V) - \mathbb{E}[\text{MI}(U,V)]}{\max(H(U), H(V)) - \mathbb{E}[\text{MI}(U,V)]} $$
Fowlkes-Mallows Index (FMI)¶
The Fowlkes-Mallows Index is the geometric mean of pairwise precision and recall:
$$ \text{FMI} = \sqrt{\frac{TP}{TP + FP} \times \frac{TP}{TP + FN}} = \sqrt{\text{Precision} \times \text{Recall}} $$
where TP, FP, and FN are defined as in the Rand Index but for pairwise comparisons.
Homogeneity and Completeness¶
Homogeneity measures whether each cluster contains only members of a single class:
$$ h = 1 - \frac{H(C|K)}{H(C)} $$
Completeness measures whether all members of a given class are assigned to the same cluster:
$$ c = 1 - \frac{H(K|C)}{H(K)} $$
where $H(C|K)$ is the conditional entropy of the true classes given the cluster assignments.
V-Measure¶
The V-Measure is the harmonic mean of homogeneity and completeness:
$$ \text{V} = \frac{2 \times h \times c}{h + c} $$
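A toy case illustrates how the two quantities pull apart: if every cluster is pure but one class is split across clusters, homogeneity is perfect while completeness is not, and the V-Measure sits between them as their harmonic mean:

```python
import numpy as np
from sklearn.metrics import homogeneity_score, completeness_score, v_measure_score

y_true = np.array([0, 0, 1, 1])
y_pred = np.array([0, 0, 1, 2])   # class 1 is split across two pure clusters

h = homogeneity_score(y_true, y_pred)   # 1: every cluster contains one class
c = completeness_score(y_true, y_pred)  # < 1: class 1 is not kept together
v = v_measure_score(y_true, y_pred)
print(h, c, np.isclose(v, 2 * h * c / (h + c)))
```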
Purity¶
Purity measures the extent to which clusters contain a single class:
$$ \text{Purity} = \frac{1}{N} \sum_{k=1}^{K} \max_j |C_k \cap T_j| $$
where $C_k$ is the set of samples in cluster $k$ and $T_j$ is the set of samples in true class $j$.
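Purity is not part of scikit-learn, so a minimal sketch of the definition (each cluster votes for its majority true class) may help; this is illustrative code, not the `ClusteringMetrics.purity_score` implementation:

```python
import numpy as np

def purity(y_true, y_pred):
    """Purity: fraction of points belonging to their cluster's dominant class."""
    total = 0
    for k in np.unique(y_pred):
        members = y_true[y_pred == k]
        total += np.bincount(members).max()   # size of the dominant class in cluster k
    return total / len(y_true)

y_true = np.array([0, 0, 0, 1, 1, 2])
y_pred = np.array([0, 0, 1, 1, 1, 1])
print(purity(y_true, y_pred))  # (2 + 2) / 6 ≈ 0.667
```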
Interpretation Guidelines:
- Higher is better: RI, ARI, MI, NMI, AMI, FMI, Homogeneity, Completeness, V-Measure, Purity
- Range [0,1]: most metrics lie in this range; exceptions are ARI (can be negative) and unnormalized MI (unbounded above)
- Perfect clustering: All metrics = 1 (except ARI where perfect = 1, random ≈ 0)
# Scikit-learn for comparison
from sklearn.mixture import GaussianMixture as SklearnGMM
from sklearn.decomposition import PCA
from sklearn.metrics import (
silhouette_score,
davies_bouldin_score,
calinski_harabasz_score,
adjusted_rand_score,
normalized_mutual_info_score,
fowlkes_mallows_score,
homogeneity_score,
mutual_info_score,
adjusted_mutual_info_score,
completeness_score,
v_measure_score,
rand_score
)
gmm = GaussianMixture(
n_features=n_features,
n_components=n_components,
covariance_type='full',
max_iter=1000,
init_params='kmeans',
device=device
)
gmm.fit(X_tensor)
labels_pred = gmm.predict(X_tensor)
# Metrics with custom ClusteringMetrics
metrics_to_compare = [
"rand_score",
"adjusted_rand_score",
"mutual_info_score",
"normalized_mutual_info_score",
"adjusted_mutual_info_score",
"fowlkes_mallows_score",
"homogeneity_score",
"completeness_score",
"v_measure_score",
"silhouette_score",
"davies_bouldin_index",
"calinski_harabasz_score",
"purity_score", # Torch-only metric
"dunn_index", # Torch-only metric
"bic_score", # Compare Torch vs Sklearn BIC
"aic_score", # Compare Torch vs Sklearn AIC
]
scores_torch = ClusteringMetrics.evaluate_clustering(
gmm,
X_tensor,
true_labels=y_tensor,
metrics=metrics_to_compare
)
# Compare with sklearn
X_np = X_tensor.cpu().numpy()
y_np = y_tensor.cpu().numpy()
labels_pred_np = labels_pred.cpu().numpy()
sk_gmm = SklearnGMM(n_components=n_components, covariance_type='full', max_iter=1000, init_params='kmeans', random_state=random_state, tol=1e-5, reg_covar=1e-7)
sk_gmm.fit(X_np)
scores_sklearn = {
"rand_score": rand_score(y_np, labels_pred_np),
"adjusted_rand_score": adjusted_rand_score(y_np, labels_pred_np),
"mutual_info_score": mutual_info_score(y_np, labels_pred_np),
"normalized_mutual_info_score": normalized_mutual_info_score(y_np, labels_pred_np),
"adjusted_mutual_info_score": adjusted_mutual_info_score(y_np, labels_pred_np),
"fowlkes_mallows_score": fowlkes_mallows_score(y_np, labels_pred_np),
"homogeneity_score": homogeneity_score(y_np, labels_pred_np),
"completeness_score": completeness_score(y_np, labels_pred_np),
"v_measure_score": v_measure_score(y_np, labels_pred_np),
"silhouette_score": silhouette_score(X_np, labels_pred_np),
"davies_bouldin_index": davies_bouldin_score(X_np, labels_pred_np),
"calinski_harabasz_score": calinski_harabasz_score(X_np, labels_pred_np),
}
sk_bic = sk_gmm.bic(X_np)
sk_aic = sk_gmm.aic(X_np)
rows = []
for metric in metrics_to_compare:
    # Retrieve the Torch metric score if available.
    torch_val = scores_torch.get(metric, None)
    # Retrieve the scikit-learn metric score if available.
    if metric in scores_sklearn:
        sklearn_val = scores_sklearn.get(metric)
    elif metric == "bic_score":
        sklearn_val = sk_bic
    elif metric == "aic_score":
        sklearn_val = sk_aic
    else:
        sklearn_val = None
    # Calculate the absolute difference and the relative difference in percent.
    if torch_val is not None and sklearn_val is not None:
        abs_diff = abs(torch_val - sklearn_val)
        # Avoid division by zero; if sklearn_val is zero, set relative difference to None.
        rel_diff = (abs_diff / abs(sklearn_val)) * 100 if sklearn_val != 0 else None
    else:
        abs_diff = None
        rel_diff = None
    rows.append({
        "Metric": metric,
        "Torch Score": torch_val,
        "Sklearn Score": sklearn_val,
        "Absolute Difference": abs_diff,
        "Relative Difference (%)": rel_diff
    })
fig, ax = plt.subplots()
plot_gmm(X=X_np, gmm=gmm, true_labels=y_tensor, match_labels_to_true=True, ax=ax, title='GMM Predictions', show_ellipses=True, show_incorrect_predictions=True, ellipse_fill=False, ellipse_std_devs=[3], point_size=10)
plt.show()
Model Comparison: TorchGMM vs Scikit-learn¶
This section compares the clustering metrics computed by our TorchGMM implementation with scikit-learn's implementations to validate correctness and highlight any differences in computation methods.
Metrics Comparison Table¶
The table below shows side-by-side comparisons of metrics computed using both implementations. Small differences may occur due to:
- Different numerical precision
- Slightly different algorithmic implementations
- Different handling of edge cases (e.g., single-point clusters)
Key observations:
- Information-theoretic metrics (MI, NMI, AMI) should be very close
- Pairwise metrics (RI, ARI, FMI) should match exactly for identical cluster assignments
- Geometric metrics (Silhouette, Davies-Bouldin, Calinski-Harabasz) may show small differences due to distance computation methods
df_metrics = pd.DataFrame(rows)
df_metrics
| | Metric | Torch Score | Sklearn Score | Absolute Difference | Relative Difference (%) |
|---|---|---|---|---|---|
| 0 | rand_score | 0.922936 | 0.922936 | 7.921422e-09 | 8.582850e-07 |
| 1 | adjusted_rand_score | 0.816571 | 0.816571 | 9.265805e-09 | 1.134721e-06 |
| 2 | mutual_info_score | 0.980560 | 0.980560 | 2.375290e-08 | 2.422381e-06 |
| 3 | normalized_mutual_info_score | 0.778043 | 0.778043 | 1.091963e-07 | 1.403475e-05 |
| 4 | adjusted_mutual_info_score | 0.777949 | 0.777777 | 1.721395e-04 | 2.213223e-02 |
| 5 | fowlkes_mallows_score | 0.871655 | 0.871655 | 3.285231e-08 | 3.768957e-06 |
| 6 | homogeneity_score | 0.773056 | 0.774893 | 1.837321e-03 | 2.371064e-01 |
| 7 | completeness_score | 0.782989 | 0.781218 | 1.771340e-03 | 2.267408e-01 |
| 8 | v_measure_score | 0.777991 | 0.778043 | 5.184373e-05 | 6.663353e-03 |
| 9 | silhouette_score | 0.184356 | 0.184356 | 4.470348e-08 | 2.424840e-05 |
| 10 | davies_bouldin_index | 1.252130 | 1.252130 | 1.186502e-07 | 9.475868e-06 |
| 11 | calinski_harabasz_score | 701.308289 | 701.308533 | 2.441406e-04 | 3.481216e-05 |
| 12 | purity_score | 0.931000 | NaN | NaN | NaN |
| 13 | dunn_index | 0.002177 | NaN | NaN | NaN |
| 14 | bic_score | 19292.576172 | 19292.308851 | 2.673210e-01 | 1.385635e-03 |
| 15 | aic_score | 19154.428955 | 19154.162397 | 2.665582e-01 | 1.391646e-03 |
Confusion Matrix and Classification Report¶
Understanding the Confusion Matrix¶
The confusion matrix provides a detailed breakdown of correct and incorrect predictions for each class. Before computing the matrix, we use label matching to align predicted cluster labels with true class labels, as clustering algorithms may assign arbitrary label numbers.
Key metrics derived from the confusion matrix:
- Precision: $\frac{TP}{TP + FP}$ - What fraction of predicted positives are actually positive?
- Recall: $\frac{TP}{TP + FN}$ - What fraction of actual positives are correctly predicted?
- F1-Score: $\frac{2 \times \text{Precision} \times \text{Recall}}{\text{Precision} + \text{Recall}}$ - Harmonic mean of precision and recall
- Support: Number of true instances for each class
The heatmap below shows the confusion matrix as percentages, normalized by the true class sizes.
# Get matched predictions
matched_pred = match_predicted_to_true_labels(y_tensor, labels_pred)
# Compute confusion matrix using ClusteringMetrics.confusion_matrix
cm = ClusteringMetrics.confusion_matrix(
y_tensor.clone().detach().to(dtype=torch.long),
matched_pred.clone().detach().to(dtype=torch.long)
)
# Convert the confusion matrix to percentages (normalize per true label)
cm_np = cm.cpu().numpy().astype(float)
cm_percent = (cm_np.T / cm_np.sum(axis=1)).T * 100
# Plot heatmap of the percent confusion matrix
plt.figure()
sns.heatmap(cm_percent, annot=True, fmt=".1f", cmap="Blues")
plt.title("Confusion Matrix (%)")
plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.show()
# Generate classification report
report = ClusteringMetrics.classification_report(
y_tensor.clone().detach().to(dtype=torch.long),
matched_pred.clone().detach().to(dtype=torch.long)
)
# Convert classification report to a DataFrame for display
report_df = pd.DataFrame(report).T
report_df
| | precision | recall | f1-score | support | jaccard | roc_auc |
|---|---|---|---|---|---|---|
| 0 | 0.951923 | 0.86625 | 0.907068 | 800.0 | 0.829940 | 0.997872 |
| 1 | 0.969072 | 0.94000 | 0.954315 | 200.0 | 0.912621 | 0.982729 |
| 2 | 0.898058 | 0.92500 | 0.911330 | 1000.0 | 0.837104 | 0.917763 |
| 3 | 0.941794 | 0.98700 | 0.963867 | 1000.0 | 0.930254 | 0.956896 |
KL Divergence Analysis¶
Kullback-Leibler Divergence Between GMMs¶
The KL divergence $D_{KL}(P \parallel Q)$ measures how one probability distribution $P$ diverges from a reference distribution $Q$:
$$ D_{KL}(P \parallel Q) = \int p(x) \log \frac{p(x)}{q(x)} dx $$
For Gaussian Mixture Models, we approximate this using Monte Carlo sampling:
$$ D_{KL}(P \parallel Q) \approx \frac{1}{N} \sum_{i=1}^{N} \left[ \log p(x_i) - \log q(x_i) \right] $$
where $x_i \sim P$ are samples drawn from the first GMM.
Interpretation:
- $D_{KL}(P \parallel Q) = 0$ when $P = Q$ (identical distributions)
- $D_{KL}(P \parallel Q) \geq 0$ always (non-negative)
- Asymmetric: $D_{KL}(P \parallel Q) \neq D_{KL}(Q \parallel P)$ in general
- Lower values indicate more similar distributions
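The Monte Carlo estimator above can be sketched with scikit-learn mixtures (a standalone illustration, not the `ClusteringMetrics.kl_divergence_gmm` implementation used below; the synthetic data is invented for the example):

```python
import numpy as np
from sklearn.mixture import GaussianMixture

rng = np.random.default_rng(0)
# Two GMMs fit to slightly different two-blob datasets.
X_p = np.vstack([rng.normal(0, 1, (500, 2)), rng.normal(5, 1, (500, 2))])
X_q = np.vstack([rng.normal(0, 1, (500, 2)), rng.normal(4, 1, (500, 2))])
p = GaussianMixture(n_components=2, random_state=0).fit(X_p)
q = GaussianMixture(n_components=2, random_state=0).fit(X_q)

def kl_mc(p, q, n_samples=20000):
    """Monte Carlo estimate of KL(p || q): average of log p(x) - log q(x)
    over samples x ~ p. score_samples() returns per-sample log-density."""
    x, _ = p.sample(n_samples)
    return np.mean(p.score_samples(x) - q.score_samples(x))

print(kl_mc(p, p) == 0.0, kl_mc(p, q) > 0.0)  # identical -> 0; different -> positive
```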
The plot below shows how KL divergence changes as we vary the number of components in the test GMM while keeping the reference GMM fixed at the true number of components.
RUN_KL_DIVERGENCE = True
if RUN_KL_DIVERGENCE:
    print("Computing KL(p||q) for different numbers of components...")
    gmm_true = GaussianMixture(
        n_features=n_features,
        n_components=n_components,
        covariance_type='full',
        max_iter=1000,
        init_params='kmeans',
        device=device
    )
    gmm_true.fit(X_tensor)
    test_range = np.arange(1, 15)
    kl_vals = torch.zeros(len(test_range), device=device)
    for i, n in tqdm(enumerate(test_range), total=len(test_range)):
        gmm_test = GaussianMixture(
            n_features=n_features,
            n_components=n,
            covariance_type='full',
            max_iter=1000,
            init_params='kmeans',
            device=device
        )
        gmm_test.fit(X_tensor)
        kl_vals[i] = ClusteringMetrics.kl_divergence_gmm(gmm_true, gmm_test, n_samples=10000)
    plt.figure()
    plt.plot(test_range, kl_vals.cpu().numpy(), marker='o')
    plt.yscale('log')
    plt.title("KL Divergence: True GMM vs. Various n_components")
    plt.xlabel("Number of Components")
    plt.ylabel("KL Divergence (log scale)")
    plt.grid(True)
    plt.show()
Computing KL(p||q) for different numbers of components...
100%|██████████| 14/14 [00:01<00:00, 12.63it/s]