Plotting and Visualization
The plotting module provides tools for visualizing GMM clustering results, including PCA-based projections for high-dimensional data.
Overview
Visualization functions include:
- 2D/3D scatter plots with cluster assignments
- PCA projections for high-dimensional data
- Confidence ellipses showing cluster distributions
- Component overlays visualizing individual Gaussians
- Interactive plots using Plotly
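Basic 2D Plotting
```python
import torch
from tgmm import GaussianMixture
from tgmm.plotting import plot_gmm_2d

# Example data: three 2D blobs (replace with your own (n_samples, 2) tensor)
torch.manual_seed(0)
centers = torch.tensor([[0.0, 0.0], [4.0, 4.0], [-4.0, 4.0]])
X = torch.cat([c + torch.randn(200, 2) for c in centers])

# Fit GMM
gmm = GaussianMixture(n_components=3, n_features=2)
gmm.fit(X)

# Plot results
plot_gmm_2d(
    X,
    gmm,
    show_ellipses=True,
    show_centers=True,
    title='GMM Clustering Results'
)
```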
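Coloring points by cluster relies on a hard assignment. Each sample goes to the component with the highest posterior responsibility, p(k | x), which is proportional to pi_k * N(x | mu_k, Sigma_k). The NumPy sketch below shows that computation for known mixture parameters. It is not the `tgmm` implementation. The names `gaussian_logpdf` and `assign_clusters`, and the `weights`/`means`/`covs` array layout, are hypothetical and chosen for illustration.

```python
import numpy as np

def gaussian_logpdf(X, mean, cov):
    """Log-density of a full-covariance Gaussian at each row of X."""
    d = X.shape[1]
    L = np.linalg.cholesky(cov)                 # cov = L @ L.T
    z = np.linalg.solve(L, (X - mean).T)        # whitened residuals, shape (d, n)
    maha = np.sum(z ** 2, axis=0)               # squared Mahalanobis distances
    log_det = 2.0 * np.sum(np.log(np.diag(L)))
    return -0.5 * (d * np.log(2.0 * np.pi) + log_det + maha)

def assign_clusters(X, weights, means, covs):
    """Return hard labels and soft responsibilities (hypothetical helper)."""
    # log pi_k + log N(x | mu_k, Sigma_k), shape (n, K)
    log_joint = np.column_stack([
        np.log(w) + gaussian_logpdf(X, m, c)
        for w, m, c in zip(weights, means, covs)
    ])
    # Normalize in log space for numerical stability
    log_norm = np.logaddexp.reduce(log_joint, axis=1, keepdims=True)
    resp = np.exp(log_joint - log_norm)
    return resp.argmax(axis=1), resp

# Usage: two well-separated components
weights = np.array([0.5, 0.5])
means = np.array([[0.0, 0.0], [5.0, 5.0]])
covs = np.array([np.eye(2), np.eye(2)])
X = np.array([[0.1, -0.2], [4.9, 5.1]])
labels, resp = assign_clusters(X, weights, means, covs)
print(labels)  # [0 1]
```

Normalizing in log space with `np.logaddexp.reduce` keeps the computation stable even when every component assigns a point a very small density.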
PCA Projection for High-Dimensional Data
When the data has more than two or three dimensions, use PCA to project it to 2D:
```python
from tgmm import GaussianMixture
from tgmm.plotting import plot_gmm_pca

# X has shape (n_samples, n_features) where n_features > 3
gmm = GaussianMixture(n_components=5, n_features=X.shape[1])
gmm.fit(X)

# Project to 2D using PCA and plot
plot_gmm_pca(
    X,
    gmm,
    n_components=2,      # Project to 2D
    show_variance=True,  # Show explained variance
    show_ellipses=True
)
```
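One way to draw ellipses in the PCA plane is to pass each fitted Gaussian through the projection itself. If y = W.T @ (x - x_mean) for an orthonormal basis W, a component N(mu, Sigma) becomes N(W.T @ (mu - x_mean), W.T @ Sigma @ W) in the projected coordinates. That is the exact marginal of the component in that plane. The NumPy sketch below illustrates this idea under that assumption. The source does not say whether `plot_gmm_pca` works this way, and `pca_basis` and `project_component` are hypothetical helpers.

```python
import numpy as np

def pca_basis(X, n_components=2):
    """Return the data mean, the top principal directions (as columns), and their variance ratios."""
    mean = X.mean(axis=0)
    # Rows of Vt are principal directions, sorted by decreasing singular value
    _, s, Vt = np.linalg.svd(X - mean, full_matrices=False)
    explained = s ** 2 / np.sum(s ** 2)
    return mean, Vt[:n_components].T, explained[:n_components]

def project_component(mu, cov, data_mean, W):
    """Project a Gaussian N(mu, cov) into PCA coordinates y = W.T @ (x - data_mean)."""
    return W.T @ (mu - data_mean), W.T @ cov @ W

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 10))
data_mean, W, explained = pca_basis(X, n_components=2)
mu2, cov2 = project_component(np.ones(10), np.eye(10), data_mean, W)
print(mu2.shape, cov2.shape)  # (2,) (2, 2)
```

Because the projection is linear, the projected 2x2 covariance can be drawn with the same ellipse routine used for native 2D data.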
3D Visualization
```python
from tgmm import GaussianMixture
from tgmm.plotting import plot_gmm_3d

# For 3D data: X_3d has shape (n_samples, 3)
gmm = GaussianMixture(n_components=4, n_features=3)
gmm.fit(X_3d)

plot_gmm_3d(
    X_3d,
    gmm,
    show_ellipses=True,
    interactive=True  # Use Plotly for rotation
)
```
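A 3D confidence ellipsoid is the set of points at a fixed squared Mahalanobis distance r2 from the component mean. The sketch below builds a surface mesh for that set by scaling a unit sphere along the covariance eigenvectors. The result is three 2D x/y/z grids, the form that `plotly.graph_objects.Surface` or Matplotlib's `plot_surface` accepts. The helper `ellipsoid_surface` is hypothetical. Its default r2 = 7.815 assumes a 95% level, which is the chi-square quantile with 3 degrees of freedom (for example, `scipy.stats.chi2.ppf(0.95, 3)`).

```python
import numpy as np

def ellipsoid_surface(mean, cov, r2=7.815, n=20):
    """Mesh grids (x, y, z) on the surface {x : (x - mean)^T cov^-1 (x - mean) = r2}.

    The default r2 = 7.815 approximates the 95% chi-square quantile for 3 d.o.f.
    """
    u = np.linspace(0.0, 2.0 * np.pi, n)
    v = np.linspace(0.0, np.pi, n)
    # Unit sphere in spherical coordinates, shape (3, n*n)
    sphere = np.stack([
        np.outer(np.cos(u), np.sin(v)).ravel(),
        np.outer(np.sin(u), np.sin(v)).ravel(),
        np.outer(np.ones_like(u), np.cos(v)).ravel(),
    ])
    # Scale each axis by sqrt(r2 * eigenvalue), then rotate by the eigenvectors
    eigvals, eigvecs = np.linalg.eigh(cov)
    transform = eigvecs @ np.diag(np.sqrt(r2 * eigvals))
    points = transform @ sphere + np.asarray(mean)[:, None]
    return [p.reshape(n, n) for p in points]

x, y, z = ellipsoid_surface(np.zeros(3), np.diag([4.0, 1.0, 0.25]))
```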
Confidence Ellipses
Visualize cluster distributions with confidence ellipses:
```python
plot_gmm_2d(
    X,
    gmm,
    show_ellipses=True,
    confidence=0.95,  # 95% confidence ellipse
    alpha=0.3         # Transparency
)
```
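For a 2D Gaussian, a `confidence` level maps to an ellipse through the eigendecomposition of the covariance. The axes point along the eigenvectors, and each semi-axis has length sqrt(r2 * eigenvalue), where r2 is the chi-square quantile with 2 degrees of freedom. In 2D that quantile has the closed form r2 = -2 ln(1 - p), so p = 0.95 gives r2 of about 5.99. The sketch below computes this geometry. The helper `confidence_ellipse` is hypothetical. Its outputs follow the argument form of `matplotlib.patches.Ellipse` (full width, full height, angle in degrees), which is not necessarily how `plot_gmm_2d` draws its ellipses.

```python
import numpy as np

def confidence_ellipse(cov, confidence=0.95):
    """Width, height and angle (degrees) of a 2D Gaussian confidence ellipse.

    For 2 degrees of freedom the chi-square quantile has the closed form
    r2 = -2 * ln(1 - confidence).
    """
    r2 = -2.0 * np.log(1.0 - confidence)
    eigvals, eigvecs = np.linalg.eigh(cov)        # eigenvalues in ascending order
    major = eigvecs[:, 1]                         # direction of largest variance
    angle = np.degrees(np.arctan2(major[1], major[0]))
    # Full axis lengths (diameters), largest first
    width, height = 2.0 * np.sqrt(r2 * eigvals[::-1])
    return width, height, angle

w, h, a = confidence_ellipse(np.array([[4.0, 0.0], [0.0, 1.0]]), 0.95)
```

Showing several levels at once (for example 0.68, 0.95, 0.99) only changes r2. The axes and orientation stay the same.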
Component Visualization
Show individual Gaussian components:
```python
from tgmm.plotting import plot_components

plot_components(
    gmm,
    component_ids=[0, 1, 2],  # Which components to show
    show_means=True,
    show_covariances=True
)
```
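A component overlay usually shows each Gaussian's weighted density, pi_k * N(x | mu_k, Sigma_k), evaluated on a grid and drawn as contours. The NumPy sketch below produces one density surface per component. The helper `component_densities_on_grid` is hypothetical. Passing `dens[k]` to Matplotlib's `plt.contour(xx, yy, dens[k])` is one possible way to draw component k, but it is an assumption, not the `plot_components` internals.

```python
import numpy as np

def component_densities_on_grid(weights, means, covs, xlim, ylim, n=100):
    """Evaluate pi_k * N(x | mu_k, Sigma_k) for each component on an n x n grid."""
    xx, yy = np.meshgrid(np.linspace(*xlim, n), np.linspace(*ylim, n))
    grid = np.column_stack([xx.ravel(), yy.ravel()])         # (n*n, 2) grid points
    densities = []
    for w, mu, cov in zip(weights, means, covs):
        diff = grid - np.asarray(mu)
        maha = np.einsum("ij,jk,ik->i", diff, np.linalg.inv(cov), diff)  # squared Mahalanobis
        norm = 1.0 / (2.0 * np.pi * np.sqrt(np.linalg.det(cov)))          # 2D Gaussian constant
        densities.append((w * norm * np.exp(-0.5 * maha)).reshape(n, n))
    return xx, yy, np.stack(densities)                       # densities: (K, n, n)

xx, yy, dens = component_densities_on_grid(
    weights=[0.3, 0.7],
    means=[np.zeros(2), np.ones(2)],
    covs=[np.eye(2), np.eye(2)],
    xlim=(-8, 8), ylim=(-8, 8), n=200,
)
```

Summing `dens` over the component axis gives the full mixture density. Each slice integrates to that component's weight.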
Convergence Monitoring
Track convergence during fitting:
```python
from tgmm.plotting import plot_convergence

gmm = GaussianMixture(n_components=3, n_features=2, verbose=True)
gmm.fit(X)

# Plot log-likelihood over iterations
plot_convergence(gmm.lower_bound_history_)
```
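An EM iteration never decreases the log-likelihood (up to floating-point error), so a healthy history should rise and then flatten. The plain-Python sketch below reports the first iteration whose improvement falls below a tolerance. It also lists any iteration where the value dropped, which usually points to a bug or numerical trouble. The sketch assumes `lower_bound_history_` is a sequence of per-iteration floats. The `check_convergence` helper and its `tol` default are hypothetical.

```python
def check_convergence(history, tol=1e-4):
    """Inspect a per-iteration log-likelihood (lower bound) history.

    Returns the first iteration whose improvement falls below `tol`
    (or None) and a list of iterations where the value decreased.
    """
    converged_at = None
    decreases = []
    for i in range(1, len(history)):
        delta = history[i] - history[i - 1]
        if delta < -1e-9:                 # EM should not go down beyond rounding noise
            decreases.append(i)
        if converged_at is None and abs(delta) < tol:
            converged_at = i
    return converged_at, decreases

print(check_convergence([-3.2, -2.1, -1.9, -1.89995, -1.89994]))
# → (3, [])
```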
Customization
Color Maps
Figure Size and DPI
Custom Styling
```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(10, 8))
plot_gmm_2d(
    X,
    gmm,
    ax=ax,  # Use custom axes
    show_ellipses=True,
    ellipse_color='red',
    center_marker='X',
    center_size=200
)
ax.set_xlabel('Feature 1', fontsize=14)
ax.set_ylabel('Feature 2', fontsize=14)
plt.tight_layout()
plt.show()
```
Advanced Examples
Side-by-Side Comparison
```python
import matplotlib.pyplot as plt
from tgmm import GaussianMixture
from tgmm.plotting import plot_gmm_2d

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
for i, cov_type in enumerate(['full', 'diag', 'spherical']):
    gmm = GaussianMixture(
        n_components=3,
        n_features=2,
        covariance_type=cov_type
    )
    gmm.fit(X)
    plot_gmm_2d(
        X,
        gmm,
        ax=axes[i],
        show_ellipses=True,
        title=f'{cov_type.capitalize()} Covariance'
    )
plt.tight_layout()
plt.show()
```
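The three panels differ because each covariance type constrains the shape of every component's ellipse. In the usual definitions (as in scikit-learn), `full` allows any orientation, `diag` keeps only per-feature variances (axis-aligned ellipses), and `spherical` uses a single variance per component (circles). The NumPy sketch below applies each constraint to one full covariance matrix. The helper `constrain_covariance` is hypothetical and only illustrates the shapes. It is not how `tgmm` estimates constrained covariances during fitting.

```python
import numpy as np

def constrain_covariance(cov, covariance_type):
    """Reduce a full covariance matrix to the structure allowed by `covariance_type`."""
    cov = np.asarray(cov, dtype=float)
    if covariance_type == "full":
        return cov                                  # any SPD matrix: rotated ellipses
    if covariance_type == "diag":
        return np.diag(np.diag(cov))                # per-feature variances only: axis-aligned ellipses
    if covariance_type == "spherical":
        d = cov.shape[0]
        return np.eye(d) * np.trace(cov) / d        # one shared variance: circles
    raise ValueError(f"unknown covariance_type: {covariance_type!r}")

cov = np.array([[2.0, 0.8], [0.8, 1.0]])
for cov_type in ["full", "diag", "spherical"]:
    print(cov_type, constrain_covariance(cov, cov_type).tolist())
```

Fewer free parameters make `diag` and `spherical` cheaper and more stable in high dimensions, at the cost of missing correlated clusters.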
Animation of EM Steps
```python
from tgmm.plotting import animate_em_steps

gmm = GaussianMixture(n_components=3, n_features=2, max_iter=20)
gmm.fit(X, track_steps=True)  # Enable step tracking

# Create animation showing EM iterations
animate_em_steps(
    X,
    gmm,
    save_path='em_animation.gif',
    fps=2
)
```
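An animation of EM needs one snapshot of the model parameters per iteration, which is presumably what `track_steps=True` records. The NumPy sketch below is a generic full-covariance EM loop, not the `tgmm` implementation. It stores the component means after every iteration, so each snapshot can become one frame. The name `em_with_snapshots`, its initialization (shared data covariance, equal weights), and the `reg` ridge term are all assumptions made for illustration.

```python
import numpy as np

def em_with_snapshots(X, init_means, n_iter=20, reg=1e-6):
    """Run plain EM for a full-covariance GMM, recording the means after each iteration."""
    n, d = X.shape
    K = len(init_means)
    means = np.array(init_means, dtype=float)
    # Assumed initialization: every component starts with the overall data covariance
    covs = np.array([np.cov(X.T) + reg * np.eye(d) for _ in range(K)])
    weights = np.full(K, 1.0 / K)
    snapshots = [means.copy()]                     # frame 0: initial means
    for _ in range(n_iter):
        # E-step: log pi_k + log N(x | mu_k, Sigma_k), normalized per sample
        log_p = np.empty((n, K))
        for k in range(K):
            diff = X - means[k]
            _, logdet = np.linalg.slogdet(covs[k])
            maha = np.einsum("ij,jk,ik->i", diff, np.linalg.inv(covs[k]), diff)
            log_p[:, k] = np.log(weights[k]) - 0.5 * (d * np.log(2 * np.pi) + logdet + maha)
        log_p -= np.logaddexp.reduce(log_p, axis=1, keepdims=True)
        resp = np.exp(log_p)                       # responsibilities; rows sum to 1
        # M-step: re-estimate weights, means and covariances from responsibilities
        nk = resp.sum(axis=0) + 1e-12
        weights = nk / n
        means = (resp.T @ X) / nk[:, None]
        for k in range(K):
            diff = X - means[k]
            covs[k] = (resp[:, k, None] * diff).T @ diff / nk[k] + reg * np.eye(d)
        snapshots.append(means.copy())             # one frame per iteration
    return snapshots, weights, means, covs

# Usage: two well-separated 2D blobs
rng = np.random.default_rng(0)
X = np.vstack([rng.normal(0.0, 0.5, (200, 2)), rng.normal(5.0, 0.5, (200, 2))])
snapshots, weights, means, covs = em_with_snapshots(X, [[1.0, 1.0], [4.0, 4.0]])
```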
Probability Contours
```python
from tgmm.plotting import plot_probability_contours

plot_probability_contours(
    X,
    gmm,
    n_levels=10,
    show_data=True,
    log_scale=True  # Use log probability
)
```
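Mixture densities span many orders of magnitude between cluster cores and empty regions, which is why contours are often drawn on log p(x). Computing log p(x) as a log-sum-exp over components avoids the underflow that occurs when you sum tiny densities first and take the log afterward. In the sketch below, `mixture_log_density` and `log_contour_levels` are hypothetical helpers. Evenly spaced levels between the minimum and maximum are one simple choice and not necessarily what `n_levels` does in `plot_probability_contours`.

```python
import numpy as np

def mixture_log_density(points, weights, means, covs):
    """log p(x) = logsumexp_k [log pi_k + log N(x | mu_k, Sigma_k)] for each point."""
    d = points.shape[1]
    terms = []
    for w, mu, cov in zip(weights, means, covs):
        diff = points - np.asarray(mu)
        _, logdet = np.linalg.slogdet(cov)
        maha = np.einsum("ij,jk,ik->i", diff, np.linalg.inv(cov), diff)
        terms.append(np.log(w) - 0.5 * (d * np.log(2 * np.pi) + logdet + maha))
    # logaddexp stays finite far from every component, where p(x) underflows to 0
    return np.logaddexp.reduce(np.stack(terms), axis=0)

def log_contour_levels(log_density, n_levels=10):
    """Evenly spaced contour levels between the min and max of log p(x)."""
    return np.linspace(log_density.min(), log_density.max(), n_levels)

xs, ys = np.meshgrid(np.linspace(-4, 8, 150), np.linspace(-4, 8, 150))
grid = np.column_stack([xs.ravel(), ys.ravel()])
logp = mixture_log_density(grid, [0.4, 0.6], [np.zeros(2), np.array([4.0, 4.0])],
                           [np.eye(2), 2.0 * np.eye(2)])
levels = log_contour_levels(logp, n_levels=10)
# plt.contour(xs, ys, logp.reshape(xs.shape), levels=levels) would draw the contours
```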
PCA Visualization Details
When working with high-dimensional data:
```python
import matplotlib.pyplot as plt
import numpy as np
import torch
from tgmm import GaussianMixture
from tgmm.plotting import plot_gmm_pca

# High-dimensional data (random noise here, so it has no real cluster structure)
X_hd = torch.randn(1000, 50)  # 50 dimensions
gmm = GaussianMixture(n_components=5, n_features=50)
gmm.fit(X_hd)

# Project to 2D and visualize
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

# Plot with cluster coloring
plot_gmm_pca(
    X_hd,
    gmm,
    n_components=2,
    ax=ax1,
    show_variance=True,
    title='PCA Projection with Clusters'
)

# Plot showing component densities
plot_gmm_pca(
    X_hd,
    gmm,
    n_components=2,
    ax=ax2,
    show_density=True,
    title='PCA Projection with Density'
)

plt.tight_layout()
plt.show()
```
Explained Variance
```python
from sklearn.decomposition import PCA

pca = PCA(n_components=10)
pca.fit(X_hd.numpy())

# Plot variance explained
plt.figure(figsize=(10, 5))

plt.subplot(1, 2, 1)
plt.plot(pca.explained_variance_ratio_, 'o-')
plt.xlabel('Component')
plt.ylabel('Explained Variance Ratio')
plt.title('Scree Plot')

plt.subplot(1, 2, 2)
plt.plot(np.cumsum(pca.explained_variance_ratio_), 'o-')
plt.xlabel('Component')
plt.ylabel('Cumulative Explained Variance')
plt.title('Cumulative Variance')

plt.tight_layout()
plt.show()
```
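Each explained variance ratio is one eigenvalue of the data covariance divided by the sum of all eigenvalues, so it can be read off the singular values of the centered data. A common next step is choosing the smallest number of components whose cumulative ratio reaches a threshold. The NumPy sketch below does that, and `components_for_variance` is a hypothetical helper. scikit-learn's `PCA` can also select components this way when `n_components` is a float between 0 and 1.

```python
import numpy as np

def components_for_variance(X, threshold=0.9):
    """Smallest number of principal components whose cumulative explained variance reaches `threshold`."""
    Xc = X - X.mean(axis=0)                      # PCA works on centered data
    s = np.linalg.svd(Xc, compute_uv=False)      # singular values, largest first
    ratio = s ** 2 / np.sum(s ** 2)              # explained variance ratio per component
    cumulative = np.cumsum(ratio)
    k = int(np.searchsorted(cumulative, threshold)) + 1
    return min(k, len(ratio)), ratio             # clip in case rounding keeps the sum below 1

rng = np.random.default_rng(0)
X = rng.normal(size=(300, 5)) * np.array([5.0, 2.0, 1.0, 0.5, 0.1])
k, ratio = components_for_variance(X, threshold=0.9)
# With the data above: components_for_variance(X_hd.numpy(), threshold=0.9)
```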
Interactive Plots with Plotly
```python
from tgmm.plotting import plot_gmm_interactive

# Create interactive 3D plot
plot_gmm_interactive(
    X_3d,
    gmm,
    title='Interactive GMM Visualization',
    show_ellipses=True,
    save_html='gmm_plot.html'
)
```
Tips for Effective Visualization
- Choose appropriate projections: Use PCA for high-dimensional data
- Adjust transparency: Use the `alpha` parameter to see overlapping clusters
- Scale data: Standardize features for better ellipse visualization
- Use confidence levels: Show different probability contours (e.g., 0.68, 0.95, 0.99)
- Compare models: Plot multiple covariance types side-by-side
- Monitor convergence: Plot the log-likelihood history to confirm that fitting has converged
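Standardizing before fitting keeps a feature with a large numeric range from flattening every ellipse into a line. If you fit on standardized data but want ellipses in the original units, map each component back with mu = mean + std * mu_z and Sigma = D @ Sigma_z @ D, where D = diag(std). The NumPy sketch below shows both directions. The helpers `standardize` and `unstandardize_component` are hypothetical.

```python
import numpy as np

def standardize(X, eps=1e-12):
    """Z-score each feature (zero mean, unit variance); also return mean and std."""
    mean = X.mean(axis=0)
    std = X.std(axis=0)
    std = np.where(std < eps, 1.0, std)          # leave constant features unscaled
    return (X - mean) / std, mean, std

def unstandardize_component(mu_z, cov_z, mean, std):
    """Map a Gaussian fitted in standardized units back to the original units."""
    D = np.diag(std)
    return mean + std * mu_z, D @ cov_z @ D

rng = np.random.default_rng(0)
X = rng.normal(loc=[10.0, -3.0], scale=[100.0, 0.1], size=(1000, 2))
Z, mean, std = standardize(X)
```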
Complete API Reference
For full details on all plotting functions, see the API Reference.