Priors in Gaussian Mixture Models¶
In this notebook, we explore how to incorporate priors into a Gaussian Mixture Model (GMM) using the tgmm package. Priors help regularize model parameters and can steer the Expectation-Maximization (EM) algorithm toward more robust, stable solutions.
We will illustrate the use of three types of priors:
- Weight Priors: Dirichlet priors on the mixture component weights.
- Mean Priors: Gaussian priors on the component means.
- Covariance Priors: Inverse-Wishart (or related) priors on the covariance matrices.
Let’s begin by setting up our environment and generating some synthetic data.
import numpy as np
import torch
import matplotlib.pyplot as plt
import os
import scipy.special
from scipy.stats import wishart, invgamma
from scipy.ndimage import gaussian_filter
os.chdir('../')
from tgmm import GaussianMixture, GMMInitializer, dynamic_figsize, plot_gmm
device = 'cuda' if torch.cuda.is_available() else 'cpu'
random_state = 42
np.random.seed(random_state)
torch.manual_seed(random_state)
if device == 'cuda':
torch.cuda.manual_seed(random_state)
print('CUDA version:', torch.version.cuda)
print('Device:', torch.cuda.get_device_name(0))
else:
print('Using CPU')
CUDA version: 12.9
Device: NVIDIA GeForce RTX 4060 Laptop GPU
Synthetic Data Generation¶
The synthetic dataset is generated by combining four Gaussian components:
- Component 1: Centered at [0, 2] with spherical covariance.
- Component 2: Centered at [2, -2] with spherical covariance (fewer points).
- Component 3: Centered at [0, 0] with diagonal covariance.
- Component 4: Centered at [2, 2] with full covariance.
n_samples = [800, 200, 1000, 1000]
centers = [np.array([0, 2]),
np.array([2, -2]),
np.array([0, 0]),
np.array([2, 2])]
covs = [
1.0 * np.eye(2), # spherical covariance
0.5 * np.eye(2), # spherical covariance, fewer points
np.array([[2, 0], [0, 0.5]]), # diagonal covariance
np.array([[0.2, 0.5], [0.5, 2]]) # full covariance
]
components = []
for n, center, cov in zip(n_samples, centers, covs):
    # Sample with the stated covariance (np.dot(randn, cov) would instead yield covariance cov @ cov.T)
    samples = np.random.multivariate_normal(center, cov, n)
    components.append(samples)
X = np.vstack(components)
labels = np.concatenate([i * np.ones(n) for i, n in enumerate(n_samples)])
legend_labels = [f'Component {i+1}' for i in range(len(n_samples))]
X_tensor = torch.tensor(X, dtype=torch.float32, device=device)
n_features = X.shape[1]
n_components = len(n_samples)
plot_gmm(X=X, true_labels=labels, title='Original Data', legend_labels=legend_labels)
plt.show()
Dirichlet Distribution: Definition and Derivation from the Beta Distribution¶
The Dirichlet distribution is a multivariate generalization of the Beta distribution. For a vector $$ \mathbf{x} = (x_1, x_2, \dots, x_K) $$ that lies in the $(K-1)$-simplex, i.e., $$ \sum_{i=1}^K x_i = 1 \quad \text{and} \quad x_i \ge 0 \quad \text{for all } i, $$ the Dirichlet distribution with parameter vector $$ \boldsymbol{\alpha} = (\alpha_1, \alpha_2, \dots, \alpha_K) \quad \text{with} \quad \alpha_i > 0, $$ is defined by its probability density function $$ f(x_1, \dots, x_K; \boldsymbol{\alpha}) = \frac{1}{B(\boldsymbol{\alpha})} \prod_{i=1}^K x_i^{\alpha_i - 1}, $$ where the normalizing constant $B(\boldsymbol{\alpha})$ is the multivariate Beta function: $$ B(\boldsymbol{\alpha}) = \frac{\prod_{i=1}^K \Gamma(\alpha_i)}{\Gamma\left(\sum_{i=1}^K \alpha_i\right)}. $$
Derivation from the Beta Distribution¶
For the special case $K=2$, let $$ x_1 = x \quad \text{and} \quad x_2 = 1 - x. $$ Then the Dirichlet density becomes $$ f(x; \alpha_1, \alpha_2) = \frac{1}{B(\alpha_1, \alpha_2)} x^{\alpha_1 - 1} (1 - x)^{\alpha_2 - 1}, $$ with $$ B(\alpha_1, \alpha_2) = \frac{\Gamma(\alpha_1)\Gamma(\alpha_2)}{\Gamma(\alpha_1 + \alpha_2)}. $$ This is exactly the density of the Beta distribution, denoted as $$ \operatorname{Beta}(\alpha_1,\alpha_2). $$
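As a quick numerical sanity check, the sketch below (using scipy.stats.beta and scipy.stats.dirichlet; the particular $\alpha$ values are arbitrary) evaluates both densities at a few points and confirms they agree for $K=2$:

from scipy.stats import beta as beta_dist, dirichlet

a1, a2 = 2.0, 5.0
for x_val in [0.1, 0.5, 0.9]:
    beta_pdf = beta_dist.pdf(x_val, a1, a2)
    dir_pdf = dirichlet.pdf([x_val, 1 - x_val], [a1, a2])
    print(f"x={x_val:.1f}  Beta pdf={beta_pdf:.4f}  Dirichlet(K=2) pdf={dir_pdf:.4f}")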
Symmetric Dirichlet Distribution¶
In its most general form, the Dirichlet distribution has a parameter vector $\boldsymbol{\alpha} = (\alpha_1, \alpha_2, \dots, \alpha_K)$, with each $\alpha_i$ playing a role similar to that in the Beta distribution. However, in many practical applications—especially when there is no prior reason to favor one category over another—a symmetric Dirichlet distribution is used. In the symmetric case, all parameters are set equal, i.e., $$ \alpha_1 = \alpha_2 = \cdots = \alpha_K = \alpha. $$ This reduces the number of free parameters from $K$ to one, and the density function simplifies to $$ f(x_1,\dots,x_K; \alpha) = \frac{\Gamma(K\alpha)}{[\Gamma(\alpha)]^K} \prod_{i=1}^K x_i^{\alpha - 1}. $$ Thus, while the general Dirichlet has $K$ parameters, the symmetric case is completely described by a single concentration parameter $\alpha$.
How Does $\alpha$ Affect the Dirichlet Distribution?¶
- Small $\alpha$ ($\alpha \ll 1$): When $\alpha$ is much less than 1, the distribution is sparse, favoring configurations where one or a few weights are near 1 while the others are close to 0. This tends to lead to “hard” clustering where only a few components dominate.
- Uniform Case ($\alpha = 1$): For $\alpha = 1$, the Dirichlet distribution is uniform over the simplex, meaning that all weight configurations are equally likely a priori.
- Large $\alpha$ ($\alpha \gg 1$): With a large $\alpha$, the distribution becomes more concentrated around the center of the simplex, promoting balanced weights across all components (see the short sampling sketch below).
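A quick way to see these regimes numerically, before the ternary plots, is to draw samples from a symmetric Dirichlet with np.random.dirichlet and check how dominant the largest weight typically is (a minimal sketch; the 0.9 threshold is just an illustrative cutoff):

rng = np.random.default_rng(random_state)
for alpha in [0.1, 1.0, 10.0]:
    samples = rng.dirichlet([alpha] * 3, size=10000)
    max_weight = samples.max(axis=1)
    print(f"alpha={alpha:>4}: mean max weight = {max_weight.mean():.3f}, "
          f"fraction with a weight > 0.9 = {(max_weight > 0.9).mean():.3f}")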
The plots below visualize how different $\alpha$ values affect the Dirichlet density over a ternary plot.
try:
import mpltern
except ImportError:
!pip install mpltern
import mpltern
def dirichlet_pdf(x, y, z, alpha):
"""
Compute the Dirichlet density for point (x, y, z)
given parameters alpha = (α1, α2, α3).
"""
a1, a2, a3 = alpha
B = (scipy.special.gamma(a1) * scipy.special.gamma(a2) * scipy.special.gamma(a3) /
scipy.special.gamma(a1 + a2 + a3))
return (x**(a1 - 1) * y**(a2 - 1) * z**(a3 - 1)) / B
alpha_list = [
((2, 2, 2), "Dirichlet Density (α=2,2,2)"),
((8, 8, 8), "Dirichlet Density (α=8,8,8)"),
((1, 1, 2), "Dirichlet Density (α=1,1,2)"),
((1, 2, 2), "Dirichlet Density (α=1,2,2)"),
((2, 4, 8), "Dirichlet Density (α=2,4,8)"),
((1, 1, 1), "Dirichlet Density (α=1,1,1)"),
]
N = 200
a_vals = np.linspace(0.0001, 0.9999, N)
b_vals = np.linspace(0.0001, 0.9999, N)
A, B_mesh = np.meshgrid(a_vals, b_vals)
C = 1 - A - B_mesh
mask = C > 0
x_grid = A[mask]
y_grid = B_mesh[mask]
z_grid = C[mask]
nrows, ncols = 3, 2
fig, axes = plt.subplots(nrows=nrows, ncols=ncols,
subplot_kw={'projection': 'ternary'},
figsize=dynamic_figsize(nrows, ncols))
for ax, (alpha, title) in zip(axes.flat, alpha_list):
densities = dirichlet_pdf(x_grid, y_grid, z_grid, alpha)
# Draw the simplex boundary.
vertices = [(0, 0, 1), (0, 1, 0), (1, 0, 0), (0, 0, 1)]
ax.plot([v[0] for v in vertices],
[v[1] for v in vertices],
[v[2] for v in vertices],
color='k', lw=2)
# Plot filled contours of the density over the simplex.
triang = ax.tricontourf(x_grid, y_grid, z_grid, densities, levels=15, cmap='viridis')
ax.tricontour(x_grid, y_grid, z_grid, densities, levels=triang.levels, colors='k', linewidths=0.5)
ax.grid(False)
ax.set_title(title)
plt.suptitle('Different Dirichlet Distributions', y=1.05)
plt.tight_layout()
plt.show()
Weight Priors in GMM: Small vs. Large $\alpha$¶
Weight priors directly influence the mixing proportions of the GMM. Here, we compare the effects of setting a very small versus a very large $\alpha$ value:
- Small $\alpha$: Leads to a model where few components dominate.
- Large $\alpha$: Encourages the model to assign roughly equal weights to all components.
Each subplot shows the fitted GMM components (ellipses and means) along with the log-likelihood (LL) achieved.
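To build intuition for why this happens, the sketch below applies the textbook MAP update for mixture weights under a Dirichlet prior, $\hat{\pi}_k \propto N_k + \alpha_k - 1$, to the component counts of our synthetic data (whether tgmm uses exactly this update internally is an assumption; the sketch only illustrates the qualitative effect):

def map_weights(counts, alpha):
    """Textbook MAP estimate of mixture weights under a symmetric Dirichlet prior."""
    counts = np.asarray(counts, dtype=float)
    unnormalized = np.clip(counts + alpha - 1.0, 0.0, None)  # guard against negative mass
    return unnormalized / unnormalized.sum()

counts = np.array(n_samples)  # 800, 200, 1000, 1000 points per component
for alpha in [1e-4, 1.0, 1e4]:
    print(f"alpha={alpha:>8}: weights = {np.round(map_weights(counts, alpha), 3)}")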
weights = np.array([1e-4, 1e4])
nrows, ncols = 2, 1
figsize = dynamic_figsize(nrows, ncols)
fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)
fig.suptitle("Different Priors: Small vs Large Weight")
for ax, weight in zip(axs, weights):
gmm_weight_prior = GaussianMixture(
n_features=n_features,
n_components=n_components,
covariance_type='full',
max_iter=1000,
device=device,
weight_concentration_prior=weight,
random_state=random_state,
init_means='random'
)
gmm_weight_prior.fit(X_tensor)
plot_gmm(X=X, gmm=gmm_weight_prior, ax=ax,
title=f"alpha={weight.item()}\nLL={gmm_weight_prior.lower_bound_:.4f}",
scale_alpha_by_weight=True,
ellipse_line_style='dashed',
ellipse_alpha=0.7,
scale_size_by_weight=False,
ellipse_colors='yellow',
ellipse_std_devs=[2],
mean_color='red',
)
plt.tight_layout()
plt.show()
Unbalanced Weight Priors: Dominant Component Effect¶
In this example, we apply unbalanced weight priors by assigning one component a significantly higher concentration parameter than the others. As expected, the dominant component tends to capture a larger share of the data, which is reflected in both the clustering results and the overall model log-likelihood.
# Example: give some components a much larger concentration than the others.
# Components with the large prior are expected to absorb more of the data.
weight1 = torch.tensor([1000.0, 1, 1, 1], device=device)
weight2 = torch.tensor([1000.0, 1000.0, 1, 1], device=device)
weight3 = torch.tensor([1000.0, 1000.0, 1000.0, 1], device=device)
weights = [weight1, weight2, weight3]
nrows, ncols = len(weights), 1
figsize = dynamic_figsize(nrows, ncols)
fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)
fig.suptitle("Different Priors: Unbalanced Weights")
for ax, weight in zip(axs, weights):
gmm_weight_prior = GaussianMixture(
n_features=n_features,
n_components=n_components,
covariance_type='full',
max_iter=1000,
device=device,
weight_concentration_prior=weight,
random_state=random_state,
init_means='random'
)
gmm_weight_prior.fit(X_tensor)
# Plot only ellipses and means, with alpha based on component weights.
plot_gmm(X=X, gmm=gmm_weight_prior, ax=ax,
title=f"alpha={weight.tolist()}\nLL={gmm_weight_prior.lower_bound_:.4f}",
scale_alpha_by_weight=True,
scale_size_by_weight=False,
ellipse_colors='yellow',
ellipse_std_devs=[2],
ellipse_line_style='dashed',
ellipse_alpha=0.7,
mean_color='red'
)
plt.tight_layout(rect=[0, 0, 1, 0.98])
plt.show()
Prior on the Means in Gaussian Mixture Models¶
In our Gaussian Mixture Model (GMM), we use a Gaussian prior on the component means to regularize the model and to guide the parameter estimation during Maximum a Posteriori (MAP) inference. This prior prevents the estimated means from drifting too far away from a reasonable central location, which is particularly useful when the data is noisy or scarce.
Gaussian Prior on the Means¶
We assume that each component mean, $\mu_k$, has a Gaussian (normal) prior of the form
$$ p(\mu_k) = \mathcal{N}(\mu_k; \mu_0, \Lambda_0^{-1}), $$
where:
- $\mu_0$ is the prior mean (often set using an initialization procedure such as K-means),
- $\Lambda_0$ is the prior precision matrix (the inverse of the prior covariance).
For simplicity, it is common to assume that the precision matrix is a scalar multiple of the identity matrix:
$$ \Lambda_0 = \lambda I, $$
which implies that the prior treats every feature equally and independently.
Impact of the Mean Precision Prior¶
The influence of the mean prior is governed by the scalar precision parameter $\lambda$ (denoted in our code as mean_precision_prior):
- Weak Prior (Small $\lambda$): When $\lambda$ is small, the prior has little influence on the estimated means. In this regime, the data likelihood dominates, and the means are largely determined by the observed data. This can lead to greater flexibility in the estimated means but may also allow for overfitting, especially with noisy data.
- Strong Prior (Large $\lambda$): Conversely, when $\lambda$ is large, the prior is very informative. In this case, the estimated means are pulled more strongly toward the prior mean $\mu_0$. This is particularly useful when you have limited data or when you want to impose prior knowledge about the location of the clusters. However, a very strong prior might oversmooth the model, preventing it from fully adapting to the data (see the shrinkage sketch below).
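The sketch below illustrates this shrinkage with the standard conjugate update for a single component mean, $\hat{\mu}_k = \frac{\lambda \mu_0 + N_k \bar{x}_k}{\lambda + N_k}$. This is a textbook formula; whether tgmm applies exactly this expression internally is an assumption, but it captures how $\lambda$ trades off the prior mean against the data mean:

mu0 = np.array([0.0, 0.0])    # prior mean
xbar = np.array([2.0, 2.0])   # empirical mean of the points assigned to the component
N_k = 50                      # effective number of points in the component

for lam in [0.1, 10.0, 1000.0]:
    mu_map = (lam * mu0 + N_k * xbar) / (lam + N_k)
    print(f"lambda={lam:>7}: MAP mean = {np.round(mu_map, 3)}")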
# Visualize the mean prior distributions as 1D Gaussians
precision_values = [1, 10, 1000]
# Create the initial means using K-means on the CPU
init_means_kmeans = GMMInitializer.kmeans(X_tensor.cpu(), k=n_components)
mean_prior = init_means_kmeans.clone().to(device)
# Create a range for plotting the Gaussian distributions
x_range = np.linspace(-3, 5, 1000)
nrows, ncols = len(precision_values), 1
figsize = dynamic_figsize(nrows, ncols)
fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)
fig.suptitle("Mean Prior Distributions (1D Visualization)", y=1)
colors = ['red', 'blue', 'green', 'orange']
component_labels = [f'Component {i+1}' for i in range(n_components)]
for ax, precision in zip(axs, precision_values):
# Standard deviation of the prior: sigma = 1 / sqrt(precision)
prior_std = 1.0 / np.sqrt(precision)
# Plot Gaussian distributions for each component's mean prior
for i in range(n_components):
# Use the first dimension of the mean for visualization
mean_val = mean_prior[i, 0].cpu().numpy()
# Calculate Gaussian PDF
gaussian_pdf = (1 / (prior_std * np.sqrt(2 * np.pi))) * \
np.exp(-0.5 * ((x_range - mean_val) / prior_std)**2)
ax.plot(x_range, gaussian_pdf, color=colors[i], linewidth=2,
label=f'{component_labels[i]}')
# Mark the mean with a vertical line
ax.axvline(mean_val, color=colors[i], linestyle='--', alpha=0.7)
ax.set_title(f"Precision λ = {precision} (σ = {prior_std:.4f})")
ax.set_xlabel("Mean Value (First Dimension)")
ax.set_ylabel("Prior Density")
ax.legend(loc='upper right')
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
# Define a list of precision values to compare
precision_values = [1, 1e2, 1e3]
# Create the initial means using K-means on the CPU
init_means_kmeans = GMMInitializer.kmeans(X_tensor.cpu(), k=n_components)
mean_prior = init_means_kmeans.clone().to(device)
nrows, ncols = len(precision_values), 1
figsize = dynamic_figsize(nrows, ncols)
fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)
for ax, precision in zip(axs, precision_values):
# Create a new GMM instance for each precision value
gmm_mean_prior = GaussianMixture(
n_features=n_features,
n_components=n_components,
covariance_type='full',
max_iter=1000,
device=device,
mean_prior=mean_prior,
mean_precision_prior=precision,
random_state=random_state,
init_means=mean_prior,
)
gmm_mean_prior.fit(X_tensor)
title = f"Mean Prior (precision={precision})\nLL={gmm_mean_prior.lower_bound_:.2f}"
# Plot the fitted means and ellipses with a base ellipse alpha of 0.5
plot_gmm(
X=X,
gmm=gmm_mean_prior,
ax=ax,
title=title,
ellipse_alpha=0.5,
ellipse_line_style='dashed',
show_initial_means=True,
ellipse_colors='blue',
ellipse_std_devs=[2],
mean_color='yellow',
mean_size=50,
initial_mean_size=50,
scale_alpha_by_weight=True,
scale_size_by_weight=False
)
ax.legend()
plt.tight_layout()
plt.show()
The Gamma Distribution and its Inverse¶
The Gamma distribution is given by the density function
$$ f(x; \alpha, \theta) = \frac{1}{\Gamma(\alpha)\,\theta^\alpha}\, x^{\alpha-1} \exp\left(-\frac{x}{\theta}\right), \quad x > 0, $$
where:
- $\alpha$ is the shape parameter,
- $\theta$ is the scale parameter,
- $\Gamma(\alpha)$ is the gamma function.
The mean of the Gamma distribution is
$$ E[X] = \alpha \theta. $$
If we define
$$ Y = \frac{1}{X}, $$
then using the standard transformation method the density of $Y$ becomes
$$ \begin{aligned} f_Y(y) &= f_X\left(\frac{1}{y}\right) \left|\frac{d}{dy}\left(\frac{1}{y}\right)\right| \\ &= \frac{1}{\Gamma(\alpha)\,\theta^\alpha} \left(\frac{1}{y}\right)^{\alpha-1} \exp\left(-\frac{1}{\theta y}\right) \frac{1}{y^2} \\ &= \frac{1}{\Gamma(\alpha)\,\theta^\alpha}\, y^{-\alpha-1} \exp\left(-\frac{1}{\theta y}\right), \quad y > 0. \end{aligned} $$
This density is of the form of an inverse Gamma distribution. A standard parameterization for the inverse Gamma distribution is:
$$ f_Y(y; \alpha, \beta) = \frac{\beta^\alpha}{\Gamma(\alpha)}\, y^{-\alpha-1} \exp\left(-\frac{\beta}{y}\right), \quad y > 0, $$
where $\beta > 0$ is the scale parameter of the inverse Gamma. By comparing the two forms, we identify
$$ \beta^\alpha = \frac{1}{\theta^\alpha} \quad \Longrightarrow \quad \beta = \frac{1}{\theta}. $$
Mean of the Inverse Gamma Distribution¶
For the inverse Gamma distribution with parameters $\alpha$ and $\beta$, the mean exists for $\alpha > 1$ and is given by
$$ E[Y] = \frac{\beta}{\alpha - 1}. $$
Substituting $\beta = \frac{1}{\theta}$ into the formula, we get
$$ E[Y] = \frac{1/\theta}{\alpha - 1} = \frac{1}{\theta (\alpha - 1)}. $$
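This relationship is easy to check numerically with scipy.stats.invgamma, whose scale argument corresponds to $\beta$ (a small sketch with arbitrary parameter values):

alpha_check, theta_check = 3.0, 0.5
beta_check = 1.0 / theta_check

# Closed-form mean vs. scipy's built-in mean
print("formula :", beta_check / (alpha_check - 1))
print("scipy   :", invgamma.mean(a=alpha_check, scale=beta_check))

# Monte Carlo check: invert Gamma(alpha, theta) samples
gamma_samples = np.random.gamma(shape=alpha_check, scale=theta_check, size=200_000)
print("1/Gamma :", (1.0 / gamma_samples).mean())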
Impact of the Parameters¶
Shape ($\alpha$):
The shape parameter $\alpha$ affects the tail behavior and the concentration of the distribution. Increasing $\alpha$ (with $\beta$ fixed) causes both the mean and the mode to decrease:
$$ E[Y] = \frac{\beta}{\alpha - 1}, \qquad \text{mode} = \frac{\beta}{\alpha + 1}. $$
Thus, a higher $\alpha$ results in a distribution that is more concentrated near zero and has lighter tails, while a lower $\alpha$ produces heavier tails.
Scale ($\beta$):
The scale parameter $\beta$ shifts the distribution. A larger $\beta$ leads to higher values of the mean and mode. In many Bayesian contexts, the inverse gamma is obtained by inverting a Gamma-distributed variable with scale parameter $\theta$, where $\beta = \frac{1}{\theta}$. In that case, increasing $\theta$ (making the Gamma distribution more spread out) results in a smaller $\beta$, which in turn shifts the inverse gamma distribution toward lower values.
Visualizing the Inverse Gamma Distribution¶
The three panels below demonstrate the following scenarios:
- Varying $\alpha$ with Fixed $\theta$: How changes in the shape parameter alter the spread and concentration of the distribution.
- Fixed $\alpha$ with Varying $\theta$: The effect of the scale parameter on the location and spread of the distribution.
- Varying Both to Maintain the Same Mode: How different combinations of $\alpha$ and $\theta$ can yield the same mode, illustrating the trade-off between the parameters.
x = np.linspace(0.001, 2, 1000)
fig, axes = plt.subplots(3, 1, figsize=(10, 18))
# -------------------------------
# Plot 1: Varying $\alpha$, fixed $\theta$
# -------------------------------
fixed_theta = 1
alphas = [1, 2, 3, 7]
for alpha in alphas:
beta = 1 / fixed_theta
pdf_vals = invgamma.pdf(x, a=alpha, scale=beta)
mode = beta / (alpha + 1)
line, = axes[0].plot(x, pdf_vals, label=f"$\\alpha={alpha},\\,\\theta={fixed_theta},\\,\\mathrm{{mode}}={mode:.3f}$")
axes[0].axvline(mode, linestyle='dashed', color=line.get_color(), linewidth=1.5)
axes[0].set_title("Inverse Gamma PDF (Varying $\\alpha$, Fixed $\\theta$)")
axes[0].set_xlabel("$x$")
axes[0].set_ylabel("Density")
axes[0].legend()
axes[0].grid(True)
# -------------------------------
# Plot 2: Fixed $\alpha$, varying $\theta$
# -------------------------------
fixed_alpha = 3
thetas = [0.25, 0.5, 1, 2]
for theta in thetas:
beta = 1 / theta
pdf_vals = invgamma.pdf(x, a=fixed_alpha, scale=beta)
mode = beta / (fixed_alpha + 1)
line, = axes[1].plot(x, pdf_vals, label=f"$\\alpha={fixed_alpha},\\,\\theta={theta},\\,\\mathrm{{mode}}={mode:.3f}$")
axes[1].axvline(mode, linestyle='dashed', color=line.get_color(), linewidth=1.5)
axes[1].set_title("Inverse Gamma PDF (Fixed $\\alpha$, Varying $\\theta$)")
axes[1].set_xlabel("$x$")
axes[1].set_ylabel("Density")
axes[1].legend()
axes[1].grid(True)
# -------------------------------
# Plot 3: Varying both $\alpha$ and $\theta$ to get the same mode
# -------------------------------
alphas_equal = [1, 3, 7, 15]
thetas_equal = [1, 0.5, 0.25, 0.125]
for alpha, theta in zip(alphas_equal, thetas_equal):
beta = 1 / theta
pdf_vals = invgamma.pdf(x, a=alpha, scale=beta)
mode = beta / (alpha + 1)
line, = axes[2].plot(x, pdf_vals, label=f"$\\alpha={alpha},\\,\\theta={theta},\\,\\mathrm{{mode}}={mode:.3f}$")
axes[2].axvline(mode, linestyle='dashed', color=line.get_color(), linewidth=1.5)
axes[2].set_title("Inverse Gamma PDF (Varying both $\\alpha$ and $\\theta$ to get the same mode)")
axes[2].set_xlabel("$x$")
axes[2].set_ylabel("Density")
axes[2].legend()
axes[2].grid(True)
plt.tight_layout()
plt.show()
The Wishart Distribution and its Inverse¶
In Bayesian modeling of Gaussian mixtures, a key step is placing a prior on the covariance matrices of each Gaussian component. Since these matrices must be positive-definite (i.e., all eigenvalues are positive), the prior must respect this constraint. While the (non-inverse) Wishart distribution is a natural multivariate analogue of the Gamma distribution, it has characteristics that are not ideal for covariance priors in a MAP estimation context. Instead, the inverse Wishart (and its one-dimensional counterpart, the inverse Gamma) is commonly used. Let’s explore why.
From the Gamma to the Wishart Distribution¶
The Gamma Distribution (Scalar Case)¶
The Gamma distribution is defined as
$$ f(x; \alpha, \theta) = \frac{1}{\Gamma(\alpha)\,\theta^\alpha}\, x^{\alpha-1} \exp\left(-\frac{x}{\theta}\right), \quad x > 0, $$
where:
- $ \alpha $ is the shape parameter,
- $ \theta $ is the scale parameter, and
- $ \Gamma(\alpha) $ is the Gamma function.
For many Bayesian models, the Gamma distribution is used as a conjugate prior for positive parameters (e.g. precision). Notice that—even though certain parameter choices can put high density near zero—its support is strictly $ (0, \infty) $, so the probability of $ x=0 $ is exactly zero.
The Wishart Distribution (Multivariate Case)¶
When we generalize to matrices, the Wishart distribution arises naturally. Its density over $ d \times d $ positive-definite matrices is
$$ p(\mathbf{S}; \nu, \Psi) \propto |\mathbf{S}|^{\frac{\nu-d-1}{2}} \exp\left(-\frac{1}{2}\operatorname{tr}\left(\Psi^{-1} \mathbf{S}\right)\right), $$
where:
- $ \mathbf{S} $ is a positive-definite matrix,
- $ \nu $ is the degrees of freedom (playing a role similar to the shape parameter),
- $ \Psi $ is the scale matrix, and
- $ d $ is the dimension.
For $ d = 1 $, the Wishart reduces to the Gamma distribution (up to constants), which shows their close connection.
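This connection can be checked directly in one dimension: a Wishart with $\nu$ degrees of freedom and scalar scale $\psi$ has the same density as a Gamma with shape $\nu/2$ and scale $2\psi$ (a small sketch; wishart is already imported above, and the parameter values are arbitrary):

from scipy.stats import gamma as gamma_dist

nu_check, psi_check = 5, 1.5
xs = np.array([0.5, 1.0, 2.0, 4.0])
print("Wishart (d=1) pdf:", wishart.pdf(xs, df=nu_check, scale=psi_check))
print("Gamma pdf        :", gamma_dist.pdf(xs, a=nu_check / 2, scale=2 * psi_check))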
Why Use the Inverse Wishart and Inverse Gamma?¶
There are two main reasons for preferring the inverse distributions as priors on covariance matrices in a GMM:
Vanishing Density Near Zero
A standard Gamma (or Wishart) distribution can place substantial mass arbitrarily close to zero: near the origin its density decays only polynomially (and for shape parameters below 1 it even diverges). This is problematic for covariance matrices, because a near-zero or singular covariance yields a degenerate Gaussian likelihood.
In contrast, the inverse Gamma density behaves like $\exp(-\beta/y)$ as $y \to 0$ (and the inverse Wishart behaves analogously near singular matrices), vanishing faster than any polynomial. The prior therefore assigns essentially no mass to degenerate covariances, which is crucial for well-defined Gaussian likelihoods.
Longer Right-Hand Tail
The inverse distributions typically have a heavier (longer) tail on the right-hand side. This longer tail means that while the prior penalizes extremely small covariances (preventing singularity), it is more tolerant of large covariance estimates.
For covariance matrices, this is important because the empirical covariance can vary widely, especially in high-dimensional settings. The long tail ensures that the prior does not overly constrain large values, thus preserving flexibility. In 2D, for example, the determinant of the covariance matrix (which is proportional to the area of the corresponding ellipse) can be thought of as capturing the “size” of the uncertainty. A heavy right tail in the inverse Wishart better accommodates the possibility of large uncertainty, similar in spirit (though not strictly proportional) to how the area of a circle grows with its radius.
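Both points can be seen numerically in one dimension by comparing Gamma and inverse-Gamma densities near zero and far in the right tail (a sketch with illustrative parameter values):

from scipy.stats import gamma as gamma_dist

a_check = 2.0
near_zero = np.array([1e-3, 1e-2, 1e-1])
far_right = np.array([5.0, 10.0, 20.0])

# Near zero the Gamma density decays only polynomially, while the inverse Gamma
# decays like exp(-beta/y) and is effectively zero.
print("Gamma pdf near zero    :", gamma_dist.pdf(near_zero, a=a_check, scale=1.0))
print("InvGamma pdf near zero :", invgamma.pdf(near_zero, a=a_check, scale=1.0))

# Far to the right the Gamma decays exponentially, the inverse Gamma only polynomially.
print("Gamma pdf in the tail   :", gamma_dist.pdf(far_right, a=a_check, scale=1.0))
print("InvGamma pdf in the tail:", invgamma.pdf(far_right, a=a_check, scale=1.0))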
# Initial covariance matrix
covariance_matrix = np.array([[1, 0.5], [0.5, 1]])
# Degrees of freedom and scale factors
dfs = [2, 3, 4]
scales = [0.5, 1, 2]
scale_matrices = [covariance_matrix * scale for scale in scales]
# Collect inverse Wishart samples
inv_wishart_samples_list = []
for df in dfs:
for scale_matrix in scale_matrices:
# Sample from the Wishart distribution
samples = wishart.rvs(df=df, scale=scale_matrix, size=1000000)
# Invert the samples to obtain inverse Wishart samples
inv_samples = np.linalg.inv(samples)
inv_wishart_samples_list.append((df, scale_matrix, inv_samples))
# Set fixed axis limits (adjust as needed)
xmin, xmax = -0.01, 2
ymin, ymax = -0.01, 2
# Set up the figure for multiple hexbin plots
nrows, ncols = len(dfs), len(scale_matrices)
figsize = dynamic_figsize(nrows, ncols, base_width=7, base_height=6)
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)
fig.suptitle("Hexbin Plots with Smoothed Gradient Contours\nof 2D Inverse Wishart Distributions", fontsize=16)
# Loop over different degrees of freedom and scale matrices
idx = 0
for i, df in enumerate(dfs):
for j, scale_matrix in enumerate(scale_matrices):
df_current, scale_matrix_current, inv_samples = inv_wishart_samples_list[idx]
idx += 1
# Extract diagonal elements (variances) from the inverse Wishart samples
x = inv_samples[:, 0, 0]
y = inv_samples[:, 1, 1]
ax = axes[i, j]
# Create hexbin plot
hb = ax.hexbin(x, y, gridsize=30, cmap='PiYG', extent=[xmin, xmax, ymin, ymax])
fig.colorbar(hb, ax=ax, orientation='vertical')
# Compute a 2D histogram for contour lines
H, xedges, yedges = np.histogram2d(x, y, bins=20, range=[[xmin, xmax], [ymin, ymax]])
# Smooth the histogram using a Gaussian filter
H_smooth = gaussian_filter(H, sigma=1)
# Compute the centers of the bins
Xc = (xedges[:-1] + xedges[1:]) / 2
Yc = (yedges[:-1] + yedges[1:]) / 2
X_mesh, Y_mesh = np.meshgrid(Xc, Yc)
# Overlay smoothed contour lines to show density gradients
ax.contour(X_mesh, Y_mesh, H_smooth.T, colors='k', linewidths=1, alpha=1, levels=10)
ax.set_aspect('equal')
ax.set_xlim([xmin, xmax])
ax.set_ylim([ymin, ymax])
ax.set_title(f"DF={df}, Scale={np.round(scale_matrix[0,0],1)}")
ax.set_xlabel("Diagonal Element 1")
ax.set_ylabel("Diagonal Element 2")
plt.tight_layout(rect=[0, 0, 1, 0.98])
plt.show()
Covariance Priors in GMM¶
Next, we demonstrate how covariance priors affect the GMM; a short sketch of the underlying MAP blend follows the list. In the following three sections, we vary:
- Prior Strength: The covariance prior is scaled by different factors.
- Degrees of Freedom (DOF): The DOF of the prior distribution is varied.
- Strength and DOF: Varying both the prior strength and the DOF.
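As a rough guide to what these knobs do, the sketch below uses the textbook MAP covariance under an inverse-Wishart prior, $\hat{\Sigma}_k = \frac{\Psi_0 + S_k}{\nu_0 + N_k + d + 2}$, where $S_k$ is the scatter matrix of the points assigned to component $k$ (this exact form is an assumption about tgmm's internals; it only illustrates how strength and DOF blend the prior with the data):

def map_covariance(psi0, scatter, n_k, nu0, d):
    """Textbook MAP covariance under an inverse-Wishart prior (assumed form)."""
    return (psi0 + scatter) / (nu0 + n_k + d + 2)

d = 2
emp_cov = np.array([[1.0, 0.3], [0.3, 0.5]])  # empirical covariance of one component
n_k = 200                                     # points assigned to that component
scatter = emp_cov * n_k                       # scatter matrix around the component mean

for strength, dof in [(1, d + 2), (25, d + 2), (25, d + 500)]:
    psi0 = np.eye(d) * strength
    print(f"strength={strength:>2}, dof={dof:>3}:")
    print(np.round(map_covariance(psi0, scatter, n_k, dof, d), 3))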
1. Varying Prior Strength¶
Here, the DOF is fixed and the covariance prior strength is varied.
data_covariance = np.cov(X_tensor.cpu().numpy(), rowvar=False) # shape (2, 2)
data_covariance = torch.tensor(data_covariance, dtype=torch.float32, device=device)
data_covariance = data_covariance.unsqueeze(0).expand(n_components, -1, -1)
degrees_of_freedom_prior = float(n_features + 2)
prior_strengths = [1, 5, 25]
nrows, ncols = len(prior_strengths), 1
figsize = dynamic_figsize(nrows, ncols)
fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)
fig.suptitle("Covariance Priors: Varying Strength")
for ax, strength in zip(axs, prior_strengths):
cov_prior = data_covariance * strength
dummy_mean_prior = torch.zeros(n_components, n_features, device=device)
dummy_mean_precision_prior = 1e-10 # effectively no strong push on means
gmm_cov_prior = GaussianMixture(
n_features=n_features,
n_components=n_components,
covariance_type='full',
max_iter=10000,
device=device,
# Covariance prior:
covariance_prior=cov_prior,
degrees_of_freedom_prior=degrees_of_freedom_prior,
# Provide "dummy" mean priors so code doesn't crash
mean_prior=dummy_mean_prior,
mean_precision_prior=dummy_mean_precision_prior
)
gmm_cov_prior.fit(X_tensor)
title = f"DOF={degrees_of_freedom_prior}, Strength={strength}\nLL={gmm_cov_prior.lower_bound_:.2f}"
# Plot the fitted components as 1-, 2-, and 3-std ellipses
plot_gmm(
X=X,
gmm=gmm_cov_prior,
ax=ax,
title=title,
ellipse_colors='green',
ellipse_std_devs=[1, 2, 3],
)
plt.tight_layout(rect=[0, 0, 1, 0.98])
plt.show()
2. Varying Degrees of Freedom (DOF)¶
Here, the covariance prior strength is fixed and we vary the DOF.
Note: For 2D, DOF must be greater than 1.
data_covariance = np.cov(X_tensor.cpu().numpy(), rowvar=False) # shape (2, 2)
data_covariance = torch.tensor(data_covariance, dtype=torch.float32, device=device)
data_covariance = data_covariance.unsqueeze(0).expand(n_components, -1, -1)
# Instead of varying strength, we fix the strength and vary DOF.
dof_values = [n_features + 8, n_features + 98, n_features + 498]
fixed_strength = 5
nrows, ncols = len(dof_values), 1
figsize = dynamic_figsize(nrows, ncols)
fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)
fig.suptitle("Covariance Priors: Varying DOF")
for ax, dof in zip(axs, dof_values):
cov_prior = data_covariance * fixed_strength
dummy_mean_prior = torch.zeros(n_components, n_features, device=device)
dummy_mean_precision_prior = 1e-10 # negligible influence on means
gmm_cov_prior = GaussianMixture(
n_features=n_features,
n_components=n_components,
covariance_type='full',
max_iter=10000,
device=device,
covariance_prior=cov_prior,
degrees_of_freedom_prior=dof,
mean_prior=dummy_mean_prior,
mean_precision_prior=dummy_mean_precision_prior
)
gmm_cov_prior.fit(X_tensor)
title = f"DOF={dof}, Strength={fixed_strength}\nLL={gmm_cov_prior.lower_bound_:.2f}"
# Plot the fitted components as 1-, 2-, and 3-std ellipses.
plot_gmm(
X=X,
gmm=gmm_cov_prior,
ax=ax,
title=title,
ellipse_colors='red',
ellipse_std_devs=[1, 2, 3]
)
plt.tight_layout(rect=[0, 0, 1, 0.98])
plt.show()
3. Varying Both Strength and DOF¶
Here, both the covariance prior strength and the DOF are varied.
# Compute data covariance from the tensor data.
data_covariance = np.cov(X_tensor.cpu().numpy(), rowvar=False) # shape (2, 2)
data_covariance = torch.tensor(data_covariance, dtype=torch.float32, device=device)
data_covariance = data_covariance.unsqueeze(0).expand(n_components, -1, -1)
# Define the lists of prior strengths and degrees of freedom to explore.
prior_strengths = [25, 10, 5]
dof_values = [n_features + 8, n_features + 98, n_features + 498]
# Create a 3x3 grid.
nrows, ncols = len(dof_values), len(prior_strengths)
figsize = dynamic_figsize(nrows, ncols, base_width=6, base_height=6)
fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)
fig.suptitle("Covariance Priors: Varying Strength and DOF")
# Loop over DOF values (rows) and prior strengths (columns)
for i, dof in enumerate(dof_values):
for j, strength in enumerate(prior_strengths):
ax = axs[i, j]
# Compute covariance prior for this combination.
cov_prior = data_covariance * strength
dummy_mean_prior = torch.zeros(n_components, n_features, device=device)
dummy_mean_precision_prior = 1e-10 # negligible influence on means
gmm_cov_prior = GaussianMixture(
n_features=n_features,
n_components=n_components,
covariance_type='full',
max_iter=10000,
device=device,
covariance_prior=cov_prior,
degrees_of_freedom_prior=dof,
mean_prior=dummy_mean_prior,
mean_precision_prior=dummy_mean_precision_prior
)
gmm_cov_prior.fit(X_tensor)
title = f"DOF={dof}, Strength={strength}\nLL={gmm_cov_prior.lower_bound_:.2f}"
# Plot the fitted components as 1-, 2-, and 3-std ellipses.
plot_gmm(
X=X,
gmm=gmm_cov_prior,
ax=ax,
title=title,
ellipse_colors='orange',
ellipse_std_devs=[1, 2, 3]
)
plt.tight_layout(rect=[0, 0, 1, 0.98])
plt.show()
Normal-Inverse-Wishart (NIW) Conjugate Prior¶
The Normal-Inverse-Wishart (NIW) prior is a conjugate prior for the parameters of a multivariate Gaussian distribution when both the mean $\boldsymbol{\mu}$ and covariance matrix $\boldsymbol{\Sigma}$ are unknown. This is fundamentally different from specifying separate priors for means and covariances.
Mathematical Formulation¶
For a multivariate Gaussian component with parameters $(\boldsymbol{\mu}, \boldsymbol{\Sigma})$, the NIW prior is defined as:
$$p(\boldsymbol{\mu}, \boldsymbol{\Sigma}^{-1}) = \mathcal{N}(\boldsymbol{\mu} | \boldsymbol{m}_0, (\lambda_0 \boldsymbol{\Sigma}^{-1})^{-1}) \cdot \mathcal{W}(\boldsymbol{\Sigma}^{-1} | \nu_0, \boldsymbol{\Psi}_0^{-1}),$$ that is, $\boldsymbol{\mu} | \boldsymbol{\Sigma} \sim \mathcal{N}(\boldsymbol{m}_0, \boldsymbol{\Sigma}/\lambda_0)$ and $\boldsymbol{\Sigma} \sim \mathcal{W}^{-1}(\nu_0, \boldsymbol{\Psi}_0)$.
Where:
- $\boldsymbol{m}_0$: prior mean (hyperparameter)
- $\lambda_0$: prior precision parameter (controls how tightly the mean is constrained around $\boldsymbol{m}_0$)
- $\nu_0$: degrees of freedom for the Wishart distribution
- $\boldsymbol{\Psi}_0$: scale matrix for the Wishart distribution
Key Properties of NIW Prior¶
Conjugacy: The posterior distribution is also NIW, making Bayesian inference analytically tractable.
Coupling: The mean and covariance are coupled through the precision parameter $\lambda_0$. When $\boldsymbol{\Sigma}$ is large (uncertain), the mean prior becomes less informative automatically.
Natural Parameterization: The NIW prior naturally captures the relationship between mean uncertainty and covariance uncertainty.
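This coupling can be seen by sampling from the prior itself: first draw $\boldsymbol{\Sigma}$ from the inverse Wishart, then draw $\boldsymbol{\mu} \mid \boldsymbol{\Sigma} \sim \mathcal{N}(\boldsymbol{m}_0, \boldsymbol{\Sigma}/\lambda_0)$. The sketch below does this with scipy.stats.invwishart; the hyperparameter values are only illustrative and are kept separate from those used in the experiment further down:

from scipy.stats import invwishart

d = 2
m0_demo, lambda0_demo, nu0_demo, Psi0_demo = np.zeros(d), 5.0, d + 3, np.eye(d)

Sigmas = invwishart.rvs(df=nu0_demo, scale=Psi0_demo, size=2000)
mus = np.array([np.random.multivariate_normal(m0_demo, S / lambda0_demo) for S in Sigmas])

# Coupling: mean draws scatter more widely when the sampled covariance is larger.
traces = np.trace(Sigmas, axis1=1, axis2=2)
large = traces > np.median(traces)
print("mean ||mu - m0|| given a large Sigma:", np.linalg.norm(mus[large], axis=1).mean())
print("mean ||mu - m0|| given a small Sigma:", np.linalg.norm(mus[~large], axis=1).mean())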
Why NIW ≠ Separate Mean + Covariance Priors¶
When using separate priors:
- Mean prior: $p(\boldsymbol{\mu}) = \mathcal{N}(\boldsymbol{\mu} | \boldsymbol{m}_0, \tau_0^2 \mathbf{I})$
- Covariance prior: $p(\boldsymbol{\Sigma}^{-1}) = \mathcal{W}(\boldsymbol{\Sigma}^{-1} | \nu_0, \boldsymbol{\Psi}_0^{-1})$
The key differences are:
Independence vs. Coupling: Separate priors assume independence between mean and covariance uncertainty, while NIW couples them naturally.
Prior Strength Interpretation: In NIW, $\lambda_0$ controls how much the mean prior strength depends on the current covariance estimate. In separate priors, mean and covariance prior strengths are fixed independently.
Bayesian Updates: NIW updates both parameters simultaneously with proper coupling, while separate priors update them independently.
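For reference, the textbook NIW posterior update for a single Gaussian, given $N$ points with sample mean $\bar{\mathbf{x}}$ and scatter matrix $\mathbf{S}$, is $\lambda_N = \lambda_0 + N$, $\boldsymbol{m}_N = \frac{\lambda_0 \boldsymbol{m}_0 + N\bar{\mathbf{x}}}{\lambda_0 + N}$, $\nu_N = \nu_0 + N$, and $\boldsymbol{\Psi}_N = \boldsymbol{\Psi}_0 + \mathbf{S} + \frac{\lambda_0 N}{\lambda_0 + N}(\bar{\mathbf{x}} - \boldsymbol{m}_0)(\bar{\mathbf{x}} - \boldsymbol{m}_0)^\top$. The sketch below implements these formulas in NumPy; how closely tgmm's internal updates follow them is an assumption:

def niw_posterior(X_comp, m0, lambda0, nu0, Psi0):
    """Textbook Normal-Inverse-Wishart posterior update for one Gaussian."""
    N, d = X_comp.shape
    xbar = X_comp.mean(axis=0)
    S = (X_comp - xbar).T @ (X_comp - xbar)  # scatter matrix around the sample mean
    lambda_n = lambda0 + N
    m_n = (lambda0 * m0 + N * xbar) / lambda_n
    nu_n = nu0 + N
    diff = (xbar - m0).reshape(-1, 1)
    Psi_n = Psi0 + S + (lambda0 * N / lambda_n) * (diff @ diff.T)
    return m_n, lambda_n, nu_n, Psi_n

X_demo = np.random.multivariate_normal([2.0, 2.0], 0.3 * np.eye(2), size=100)
m_n, lambda_n, nu_n, Psi_n = niw_posterior(X_demo, np.zeros(2), 10.0, 4.0, np.eye(2))
print("posterior mean    :", np.round(m_n, 3))
print("posterior E[Sigma]:", np.round(Psi_n / (nu_n - 2 - 1), 3))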
Demonstration¶
We'll compare three scenarios:
- Mean Prior Only: Prior on means, MLE for covariances
- Covariance Prior Only: Prior on covariances, MLE for means
- NIW Prior: Coupled prior on both parameters
This will highlight how the coupling in NIW differs from independent priors.
# Set up data and parameters
np.random.seed(42)
torch.manual_seed(42)
# Create synthetic 2D data with clear structure
n_samples_per_component = 150
true_means = np.array([[2.0, 2.0], [-2.0, -2.0], [2.0, -2.0]])
true_covs = np.array([
[[0.5, 0.2], [0.2, 0.3]],
[[0.8, -0.1], [-0.1, 0.4]],
[[0.3, 0.0], [0.0, 0.6]]
])
# Generate data
X_list = []
for i, (mean, cov) in enumerate(zip(true_means, true_covs)):
X_component = np.random.multivariate_normal(mean, cov, n_samples_per_component)
X_list.append(X_component)
X = np.vstack(X_list)
X_tensor = torch.tensor(X, dtype=torch.float32, device=device)
n_samples, n_features = X.shape
n_components = 3
# Define NIW hyperparameters
lambda0 = 10 # Prior precision parameter (small = less coupling)
nu0 = n_features + 100 # Degrees of freedom
m0 = torch.zeros(n_components, n_features, device=device) # Prior mean
# Scale matrix (inverse of what we expect the covariance to be)
Psi0 = torch.eye(n_features, device=device).unsqueeze(0).expand(n_components, -1, -1) * 10
print("NIW Hyperparameters:")
print(f"λ₀ (precision parameter): {lambda0}")
print(f"ν₀ (degrees of freedom): {nu0}")
print(f"m₀ (prior mean): {m0[0].cpu().numpy()}")
print(f"Ψ₀ (scale matrix): \n{Psi0[0].cpu().numpy()}")
# Compare different prior configurations
fig, axes = plt.subplots(2, 2, figsize=(15, 12))
fig.suptitle("Comparison: Separate Priors vs. NIW Conjugate Prior", fontsize=16)
# 1. Mean Prior Only (top-left)
ax = axes[0, 0]
gmm_mean_only = GaussianMixture(
n_features=n_features,
n_components=n_components,
covariance_type='full',
max_iter=1000,
device=device,
mean_prior=m0,
mean_precision_prior=lambda0,
# No covariance prior - uses MLE
)
gmm_mean_only.fit(X_tensor)
plot_gmm(
X=X,
gmm=gmm_mean_only,
ax=ax,
title=f"Mean Prior Only\nLL: {gmm_mean_only.lower_bound_:.2f}",
ellipse_colors='blue',
ellipse_std_devs=[1, 2]
)
# 2. Covariance Prior Only (top-right)
ax = axes[0, 1]
gmm_cov_only = GaussianMixture(
n_features=n_features,
n_components=n_components,
covariance_type='full',
max_iter=1000,
device=device,
covariance_prior=Psi0,
degrees_of_freedom_prior=nu0,
# No mean prior - uses MLE
)
gmm_cov_only.fit(X_tensor)
plot_gmm(
X=X,
gmm=gmm_cov_only,
ax=ax,
title=f"Covariance Prior Only\nLL: {gmm_cov_only.lower_bound_:.2f}",
ellipse_colors='red',
ellipse_std_devs=[1, 2]
)
# 3. Both Priors Separately (bottom-left)
ax = axes[1, 0]
gmm_separate = GaussianMixture(
n_features=n_features,
n_components=n_components,
covariance_type='full',
max_iter=1000,
device=device,
mean_prior=m0,
mean_precision_prior=lambda0,
covariance_prior=Psi0,
degrees_of_freedom_prior=nu0,
)
gmm_separate.fit(X_tensor)
plot_gmm(
X=X,
gmm=gmm_separate,
ax=ax,
title=f"Separate Mean + Cov Priors\nLL: {gmm_separate.lower_bound_:.2f}",
ellipse_colors='purple',
ellipse_std_devs=[1, 2]
)
# 4. NIW Conjugate Prior (bottom-right)
# Note: This uses the same hyperparameters but with proper NIW coupling
ax = axes[1, 1]
# In the current implementation, when both mean and covariance priors are provided,
# it should use NIW conjugate updates. Let's verify this is working correctly.
gmm_niw = GaussianMixture(
n_features=n_features,
n_components=n_components,
covariance_type='full',
max_iter=1000,
device=device,
mean_prior=m0,
mean_precision_prior=lambda0,
covariance_prior=Psi0,
degrees_of_freedom_prior=nu0,
)
gmm_niw.fit(X_tensor)
plot_gmm(
X=X,
gmm=gmm_niw,
ax=ax,
title=f"NIW Conjugate Prior\nLL: {gmm_niw.lower_bound_:.2f}",
ellipse_colors='green',
ellipse_std_devs=[1, 2]
)
plt.tight_layout()
plt.show()
# Print parameter comparison
print("\n" + "="*80)
print("PARAMETER COMPARISON")
print("="*80)
print(f"\n1. Mean Prior Only - Final Means:")
print(gmm_mean_only.means_.cpu().numpy())
print(f" Mean distances from prior: {torch.norm(gmm_mean_only.means_ - m0, dim=1).cpu().numpy()}")
print(f"\n2. Covariance Prior Only - Final Means:")
print(gmm_cov_only.means_.cpu().numpy())
print(f" Mean distances from prior: {torch.norm(gmm_cov_only.means_ - m0, dim=1).cpu().numpy()}")
print(f"\n3. Separate Priors - Final Means:")
print(gmm_separate.means_.cpu().numpy())
print(f" Mean distances from prior: {torch.norm(gmm_separate.means_ - m0, dim=1).cpu().numpy()}")
print(f"\n4. NIW Conjugate - Final Means:")
print(gmm_niw.means_.cpu().numpy())
print(f" Mean distances from prior: {torch.norm(gmm_niw.means_ - m0, dim=1).cpu().numpy()}")
NIW Hyperparameters:
λ₀ (precision parameter): 10
ν₀ (degrees of freedom): 102
m₀ (prior mean): [0. 0.]
Ψ₀ (scale matrix):
[[10.  0.]
 [ 0. 10.]]
================================================================================
PARAMETER COMPARISON
================================================================================

1. Mean Prior Only - Final Means:
[[ 1.8937136  1.8988997]
 [ 1.9392073 -1.844294 ]
 [-1.8908288 -1.9037396]]
   Mean distances from prior: [2.681785  2.676181  2.6831806]

2. Covariance Prior Only - Final Means:
[[ 2.0701602 -1.9716756]
 [-2.0188365 -2.0305052]
 [ 2.0182817  2.022977 ]]
   Mean distances from prior: [2.8588579 2.863329  2.8576035]

3. Separate Priors - Final Means:
[[ 1.9392515 -1.8461971]
 [-1.8925295 -1.9034142]
 [ 1.8932719  1.8979424]]
   Mean distances from prior: [2.677525  2.6841486 2.6807954]

4. NIW Conjugate - Final Means:
[[ 1.8932719  1.8979424]
 [ 1.9392515 -1.8461971]
 [-1.8925295 -1.9034142]]
   Mean distances from prior: [2.6807954 2.677525  2.6841486]