
Variational Mutual Information Estimation

1. The Challenge of Estimating Mutual Information

Mutual Information (MI) \(I(X; Y)\) measures the total statistical dependence, linear or nonlinear, between two random variables. For continuous \(X\) and \(Y\) with joint density \(p(x,y)\) and marginals \(p(x)\), \(p(y)\):

\[ I(X; Y) = \mathbb{E}_{p(x,y)}\left[ \log \frac{p(x,y)}{p(x)p(y)} \right] \]
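When the densities are known, the definition can be evaluated directly by averaging the log density ratio over samples from \(p(x,y)\). The sketch below assumes a standard bivariate Gaussian with correlation \(\rho\), for which the closed form is \(I(X;Y) = -\frac{1}{2}\log(1-\rho^2)\); the helper names are illustrative.

```python
import numpy as np

def log_density_ratio(x, y, rho):
    # log p(x,y) - log p(x) - log p(y) for a standard bivariate Gaussian with correlation rho
    return (-0.5 * np.log(1 - rho**2)
            - (rho**2 * x**2 - 2 * rho * x * y + rho**2 * y**2) / (2 * (1 - rho**2)))

def mc_mutual_information(rho, n=200_000, seed=0):
    rng = np.random.default_rng(seed)
    cov = np.array([[1.0, rho], [rho, 1.0]])
    xy = rng.multivariate_normal([0.0, 0.0], cov, size=n)  # samples from p(x, y)
    # MI is the expected log density ratio under the joint distribution
    return log_density_ratio(xy[:, 0], xy[:, 1], rho).mean()

rho = 0.8
print(f"Monte Carlo: {mc_mutual_information(rho):.4f} nats, "
      f"closed form: {-0.5 * np.log(1 - rho**2):.4f} nats")
```

This works only because \(p(x,y)\), \(p(x)\) and \(p(y)\) are available in closed form; the variational estimators below replace the unknown log density ratio with a learned critic.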

In high dimensions, estimating MI directly with histograms, kernel density estimates (KDE), or \(k\)-NN estimators becomes unreliable because of the curse of dimensionality. Variational mutual information estimators, introduced in recent deep learning work, instead use neural networks to parameterize lower bounds on MI. Maximizing such a bound by gradient ascent tightens it, and the optimized value serves as the MI estimate; how accurate that estimate is depends on the choice of bound, the capacity of the network, and the number of samples per batch.

2. The NWJ (Nguyen-Wainwright-Jordan) Lower Bound

The NWJ bound (also known as the MINE-f bound) utilizes the \(f\)-divergence formulation to bound the Kullback-Leibler divergence.

Theorem 2.1 (The NWJ Bound on Mutual Information)

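Let \(T_\theta(x,y)\) be a critic function parameterized by a neural network with weights \(\theta\). The NWJ objective lower-bounds MI for every \(\theta\), and therefore also at its supremum: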

\[ I(X; Y) \ge \sup_{\theta} \mathbb{E}_{p(x,y)}[T_\theta(X, Y)] - e^{-1} \mathbb{E}_{p(x)p(y)}[e^{T_\theta(X, Y)}] \]

Proof: Mutual information is the KL divergence between the joint distribution \(P = p(x,y)\) and the product of marginals \(Q = p(x)p(y)\): \(I(X; Y) = D_{KL}(P || Q)\). The KL divergence can be expressed using its variational dual representation (the Donsker-Varadhan representation). For any function \(T\):

\[ D_{KL}(P || Q) \ge \mathbb{E}_P[T(X,Y)] - \log \mathbb{E}_Q[e^{T(X,Y)}] \]

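The NWJ bound is a variant that is looser for any fixed critic but easier to estimate without bias. It is looser because \(\log z \le z/e\) for all \(z > 0\) (with equality at \(z = e\)), so \(e^{-1}\mathbb{E}_Q[e^{T}] \ge \log \mathbb{E}_Q[e^{T}]\) and the NWJ objective never exceeds the Donsker-Varadhan objective for the same \(T\). To derive it, start from the convex conjugate (Fenchel dual) of \(f(u) = u \log u\), the generator of the KL divergence in the \(f\)-divergence family: \(f^*(v) = \sup_{u > 0} (uv - u \log u) = e^{v-1}\). By the convex-duality (variational) representation of \(f\)-divergences: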

\[ D_f(P || Q) \ge \sup_T \left( \mathbb{E}_P[T(X,Y)] - \mathbb{E}_Q[f^*(T(X,Y))] \right) \]

Substitute \(f^*(T) = e^{T-1} = e^{-1} e^T\):

\[ I(X; Y) \ge \mathbb{E}_{p(x,y)}[T_\theta(X,Y)] - \mathbb{E}_{p(x)p(y)}[e^{T_\theta(X,Y) - 1}] \]

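This matches the NWJ bound. Its practical advantage is that the expectation over \(Q\) (the product of marginals) is not wrapped in a logarithm, so for a fixed critic the minibatch average of the objective is an unbiased estimate of the bound. (It is not an unbiased estimate of the MI itself, and the Donsker-Varadhan objective, which takes the logarithm of a sample mean, is biased.) The optimal critic is \(T^*(x,y) = f'\!\left(\frac{p(x,y)}{p(x)p(y)}\right) = \log \frac{p(x,y)}{p(x)p(y)} + 1\). Substituting it gives \(\mathbb{E}_P\left[\log \frac{p(x,y)}{p(x)p(y)}\right] + 1 - \mathbb{E}_Q\left[\frac{p(x,y)}{p(x)p(y)}\right] = I(X;Y) + 1 - 1\), so equality holds. \(\blacksquare\)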
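The two facts this derivation relies on can be sanity-checked numerically: the conjugate \(f^*(v) = e^{v-1}\), and the inequality "DV objective \(\ge\) NWJ objective" for a fixed critic. The sketch below uses a brute-force grid for the supremum and hypothetical critic outputs (Gaussian draws standing in for \(T\) evaluated on joint and on product-of-marginals samples); it is an illustration, not an estimator.

```python
import numpy as np

# (1) Convex conjugate of f(u) = u log u, computed by brute force on a grid:
#     f*(v) = sup_{u > 0} (u v - u log u), which should equal exp(v - 1)
u = np.linspace(1e-6, 50.0, 2_000_001)
for v in [-1.0, 0.0, 1.0, 2.0]:
    grid_sup = np.max(u * v - u * np.log(u))
    print(f"v = {v:+.1f}: grid sup = {grid_sup:.6f}, exp(v - 1) = {np.exp(v - 1):.6f}")

# (2) For a fixed critic, DV >= NWJ, because log(z) <= z / e for every z > 0
def dv_and_nwj(t_joint, t_marginal):
    dv = t_joint.mean() - np.log(np.exp(t_marginal).mean())
    nwj = t_joint.mean() - np.exp(t_marginal - 1.0).mean()
    return dv, nwj

rng = np.random.default_rng(0)
# Hypothetical critic outputs on joint samples and on product-of-marginals samples
t_joint = rng.normal(1.0, 1.0, size=10_000)
t_marginal = rng.normal(0.0, 1.0, size=10_000)
for shift in [-2.0, 0.0, 2.0]:
    dv, nwj = dv_and_nwj(t_joint + shift, t_marginal + shift)
    print(f"shift {shift:+.1f}: DV = {dv:.4f}, NWJ = {nwj:.4f}")
```

Adding a constant to the critic leaves the DV value unchanged but moves the NWJ value; the two coincide exactly when the marginal average of \(e^{T}\) equals \(e\), which is the condition the best constant shift of an NWJ critic satisfies.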

3. InfoNCE and Contrastive Learning

The InfoNCE (Noise-Contrastive Estimation) bound is the backbone of modern Self-Supervised Learning (SSL) algorithms like SimCLR and CLIP. It bounds MI by treating estimation as a classification problem.

Theorem 3.1 (The InfoNCE Bound)

Given a joint sample \((x_i, y_i) \sim p(x,y)\) and \(K-1\) negative samples \(y_j \sim p(y)\), \(j \neq i\), drawn independently from the marginal, the InfoNCE loss for a critic function \(f(x,y)\) provides a lower bound:

\[ I(X; Y) \ge \log K - \mathcal{L}_{InfoNCE} \]

where

\[ \mathcal{L}_{InfoNCE} = - \mathbb{E}\left[ \log \frac{e^{f(x_i, y_i)}}{e^{f(x_i, y_i)} + \sum_{j \neq i} e^{f(x_i, y_j)}} \right] \]

Proof: Consider a set of \(K\) pairs, where exactly one pair is drawn from the joint distribution \(p(x,y)\), and the remaining \(K-1\) pairs are drawn from the marginal product \(p(x)p(y)\). Let the index of the true pair be \(I \in \{1, \dots, K\}\). Given the set of \(K\) samples \(S = \{(x, y_1), \dots, (x, y_K)\}\), we want to predict \(I\). The optimal Bayesian classifier computes the posterior probability:

\[ p(I=i | S) = \frac{p(S | I=i) p(I=i)}{\sum_{j=1}^K p(S | I=j) p(I=j)} \]

Assume a uniform prior \(p(I=i) = 1/K\). Because \(x\) is drawn only once and the negatives are independent of it, the likelihood is \(p(S | I=i) = p(x, y_i) \prod_{k \neq i} p(y_k)\). Dividing numerator and denominator by \(p(x) \prod_{k=1}^K p(y_k)\) gives:

\[ p(I=i | S) = \frac{ \frac{p(x, y_i)}{p(x)p(y_i)} }{ \sum_{j=1}^K \frac{p(x, y_j)}{p(x)p(y_j)} } \]

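We train a neural network \(f(x,y)\) to identify the positive pair with categorical cross-entropy, using \(f(x, y_j)\) as the logit of candidate \(j\). The optimal critic reproduces the posterior above, so \(f^*(x,y) = \log \frac{p(x,y)}{p(x)p(y)}\) up to an additive term that depends only on \(x\) (it cancels in the softmax). At this optimum the expected cross-entropy reaches its minimum, the conditional entropy \(H(I | S)\):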

\[ H(I | S) = - \mathbb{E}\left[ \log \frac{ \frac{p(x, y_i)}{p(x)p(y_i)} }{ \frac{p(x, y_i)}{p(x)p(y_i)} + \sum_{j \neq i} \frac{p(x, y_j)}{p(x)p(y_j)} } \right] \]

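Since \(I\) is uniform over \(K\) choices, \(H(I) = \log K\). The mutual information between the index and the candidate set, \(H(I) - H(I|S)\), is non-negative, so \(H(I|S) \le \log K\). For any critic \(f(x,y)\) used in place of the log density ratio, the cross-entropy cannot fall below its minimum, so \(\mathcal{L}_{InfoNCE} \ge H(I|S)\). It remains to connect these quantities to \(I(X; Y)\):

\[ I(X; Y) \ge \log K - \mathcal{L}_{InfoNCE} \]

To see this, relabel the candidates so that the positive is \(y_1\). Because the negatives \(y_2, \dots, y_K\) are independent of \((x, y_1)\), \(I(X; Y) = I(X; Y_{1:K})\). Apply the NWJ bound of Section 2 to the multi-sample critic \(T(x, y_{1:K}) = 1 + \log \frac{e^{f(x, y_1)}}{\frac{1}{K}\sum_{k=1}^K e^{f(x, y_k)}}\). Its expectation under the joint distribution is \(1 + \log K - \mathcal{L}_{InfoNCE}\). Under the product of marginals all \(y_k\) are exchangeable, so \(e^{-1}\mathbb{E}[e^{T}] = \mathbb{E}\left[\frac{e^{f(x,y_1)}}{\frac{1}{K}\sum_{k} e^{f(x,y_k)}}\right] = 1\). Subtracting the two terms gives the inequality for every critic \(f\), including the optimal one, for which \(\mathcal{L}_{InfoNCE} = H(I|S)\). Hence \(\log K - \mathcal{L}_{InfoNCE}\) is a valid, though generally not tight, lower bound on \(I(X; Y)\). \(\blacksquare\)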
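The key algebraic step of the proof, dividing the likelihood by \(p(x)\prod_k p(y_k)\), says that the Bayes posterior over the candidate index is a softmax of log density ratios. The NumPy sketch below checks this for one candidate set, assuming a standard bivariate Gaussian with \(\rho = 0.8\); the helper names are illustrative.

```python
import numpy as np

RHO = 0.8  # assumed correlation of a standard bivariate Gaussian (illustrative)

def log_normal(z, var):
    # log density of N(0, var) evaluated at z
    return -0.5 * np.log(2 * np.pi * var) - z**2 / (2 * var)

def log_joint(x, y, rho=RHO):
    # log density of a standard bivariate Gaussian with correlation rho
    return (-np.log(2 * np.pi) - 0.5 * np.log(1 - rho**2)
            - (x**2 - 2 * rho * x * y + y**2) / (2 * (1 - rho**2)))

def softmax(logits):
    w = np.exp(logits - logits.max())
    return w / w.sum()

def posterior_bayes(x, ys):
    # log p(S | I=i) = log p(x, y_i) + sum_{k != i} log p(y_k); uniform prior over i
    K = len(ys)
    log_lik = np.array([log_joint(x, ys[i]) + sum(log_normal(ys[k], 1.0) for k in range(K) if k != i)
                        for i in range(K)])
    return softmax(log_lik)

def posterior_ratio(x, ys):
    # softmax over log p(x, y_j) / (p(x) p(y_j)), the right-hand side of the derivation
    return softmax(log_joint(x, ys) - log_normal(x, 1.0) - log_normal(ys, 1.0))

rng = np.random.default_rng(0)
x = rng.standard_normal()
ys = rng.standard_normal(8)  # the identity holds for any candidate set
print(np.allclose(posterior_bayes(x, ys), posterior_ratio(x, ys)))  # True
```

The two computations differ only by the factor \(p(x)\prod_k p(y_k)\), which is the same for every candidate and cancels in the normalization.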

The \(\log K\) Limit

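A critical limitation of InfoNCE is that, because \(\mathcal{L}_{InfoNCE} \ge 0\) (it is a cross-entropy), the bound \(\log K - \mathcal{L}_{InfoNCE}\) can never exceed \(\log K\). If the true MI is 15 nats but we only use a batch size of \(K=128\), the largest achievable bound is \(\log(128) \approx 4.85\) nats. Reaching a bound equal to \(I(X;Y)\) requires at least \(K \ge e^{I(X;Y)}\), which for 15 nats means \(K \ge e^{15} \approx 3.3 \times 10^6\), so accurately estimating high MI with InfoNCE requires enormous batch sizes.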
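The batch-size requirement follows directly from \(\log K \ge I(X;Y)\). A minimal sketch (the helper name is illustrative); note that this is only a necessary condition: a large enough \(K\) makes a tight bound possible, not guaranteed.

```python
import math

def min_batch_size(mi_nats):
    # InfoNCE obeys bound <= log K, so a bound of mi_nats needs log K >= mi_nats, i.e. K >= e^mi
    return math.ceil(math.exp(mi_nats))

print(f"log(128) = {math.log(128):.2f} nats")           # log(128) = 4.85 nats
print(f"K needed for 15 nats: {min_batch_size(15):,}")  # K needed for 15 nats: 3,269,018
```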

4. Worked Examples

Example 1: Calculating the Optimal NWJ Critic

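Let \(X \sim \mathcal{N}(0, 1)\) and \(Y = X + Z\) where \(Z \sim \mathcal{N}(0, \sigma^2)\) is independent of \(X\). The pair is jointly Gaussian, and \(I(X; Y) = \frac{1}{2} \log(1 + 1/\sigma^2)\). The theoretical optimal critic for the NWJ bound is \(T^*(x,y) = 1 + \log \frac{p(x,y)}{p(x)p(y)}\). For this model the log density ratio works out to

\[ \log \frac{p(x,y)}{p(x)p(y)} = \frac{xy}{\sigma^2} - \frac{x^2}{2\sigma^2} - \frac{y^2}{2\sigma^2(1+\sigma^2)} + \frac{1}{2}\log\left(1 + \frac{1}{\sigma^2}\right) \]

so \(T^*(x,y)\) is a quadratic function of \(x\) and \(y\) that includes the cross term \(xy\). A critic that is linear in \((x, y)\) cannot represent this. Since \(X\) and \(Y\) have zero mean, a linear critic \(T = ax + by + c\) has \(\mathbb{E}_P[T] = c\), while \(\mathbb{E}_Q[e^{T-1}] \ge e^{c-1}\); the best such critic (\(a = b = 0\), \(c = 1\)) therefore reaches a bound of only 0.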
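The closed-form critic can be checked by Monte Carlo. This NumPy sketch assumes \(\sigma^2 = 0.5\) (so \(I(X;Y) = \frac{1}{2}\log 3 \approx 0.549\) nats) and uses illustrative function names; it evaluates the NWJ objective with \(T^*\) and with the best linear critic, drawing independent samples \(X \sim \mathcal{N}(0,1)\), \(Y \sim \mathcal{N}(0, 1+\sigma^2)\) for the product of marginals.

```python
import numpy as np

def optimal_nwj_critic(x, y, sigma2):
    # T*(x, y) = 1 + log p(x,y) / (p(x) p(y)) for X ~ N(0, 1), Y = X + Z, Z ~ N(0, sigma2)
    log_ratio = (x * y / sigma2
                 - x**2 / (2 * sigma2)
                 - y**2 / (2 * sigma2 * (1 + sigma2))
                 + 0.5 * np.log(1 + 1 / sigma2))
    return 1.0 + log_ratio

def constant_critic(x, y, sigma2):
    # Best critic among those linear in (x, y): zero slopes, intercept 1
    return np.ones_like(x)

def nwj_objective(critic, sigma2=0.5, n=200_000, seed=0):
    rng = np.random.default_rng(seed)
    # Joint samples: Y = X + Z
    x = rng.standard_normal(n)
    y = x + np.sqrt(sigma2) * rng.standard_normal(n)
    # Independent samples from the marginals p(x) p(y)
    x_m = rng.standard_normal(n)
    y_m = np.sqrt(1 + sigma2) * rng.standard_normal(n)
    return critic(x, y, sigma2).mean() - np.exp(critic(x_m, y_m, sigma2) - 1.0).mean()

sigma2 = 0.5
print(f"true MI          : {0.5 * np.log(1 + 1 / sigma2):.4f} nats")
print(f"NWJ with T*      : {nwj_objective(optimal_nwj_critic):.4f} nats")
print(f"NWJ, constant T  : {nwj_objective(constant_critic):.4f} nats")  # exactly 0
```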

Example 2: SimCLR and InfoNCE

In SimCLR, an image undergoes two random augmentations, producing views \(v_1\) and \(v_2\) that play the role of the positive pair \((x_i, y_i)\); augmented views of the other images in the batch serve as negatives. By minimizing \(\mathcal{L}_{InfoNCE}\) on the embeddings of these views, SimCLR maximizes a lower bound on the mutual information between the embeddings of two augmentations of the same image. This encourages the encoder to keep information shared across views and to discard what the augmentations change, which tends to produce augmentation-invariant semantic features.
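SimCLR's loss, called NT-Xent in the SimCLR paper, is InfoNCE applied to \(2N\) views: each view's positive is the other augmentation of the same image, and the remaining \(2N - 2\) views in the batch are its negatives. The PyTorch sketch below shows that construction; the function name and the default temperature are assumptions, and details such as the encoder and projection head are omitted.

```python
import torch
import torch.nn.functional as F

def nt_xent_loss(z1, z2, temperature=0.5):
    """Sketch of a SimCLR-style contrastive loss over 2N augmented views.
    z1[k] and z2[k] are embeddings of two augmentations of image k."""
    n = z1.size(0)
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=-1)  # (2N, D), unit-norm rows
    sim = z @ z.T / temperature                           # (2N, 2N) scaled cosine similarities
    # Exclude each view's similarity with itself from the softmax denominator
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(self_mask, float("-inf"))
    # Row k's positive is row k + N, and row k + N's positive is row k
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    # Each row is a (2N - 1)-way classification: 1 positive and 2N - 2 negatives
    return F.cross_entropy(sim, targets)

torch.manual_seed(0)
z1 = torch.randn(32, 128)
aligned = nt_xent_loss(z1, z1 + 0.1 * torch.randn(32, 128))  # the two views agree
unrelated = nt_xent_loss(z1, torch.randn(32, 128))           # the "views" share nothing
print(f"aligned: {aligned.item():.3f}, unrelated: {unrelated.item():.3f}")
```

Because each row classifies among \(2N - 1\) candidates, the corresponding MI bound is capped at \(\log(2N - 1)\).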

Example 3: When does InfoNCE fail?

When the dependence is weak (low mutual information, e.g. very noisy data), \(I(X;Y)\) sits well below \(\log K\) and the InfoNCE bound can be fairly tight. In nearly deterministic settings, such as robotics simulation or invertible layers, MI becomes very large (for a continuous variable passed through a deterministic invertible map it is infinite). InfoNCE then quickly hits the \(\log K\) ceiling: the loss approaches zero, its gradients become very small, and the network receives little signal for capturing finer details.
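Both regimes can be reproduced with an idealized critic. This NumPy sketch assumes standard bivariate Gaussians with correlation \(\rho\) and uses the true log density ratio as the critic (the best case for InfoNCE), then evaluates \(\log K - \mathcal{L}_{InfoNCE}\) with in-batch negatives for \(K = 64\); the function names are illustrative.

```python
import numpy as np

def log_ratio(x, y, rho):
    # True log p(x,y) / (p(x) p(y)) for a standard bivariate Gaussian with correlation rho
    return (-0.5 * np.log(1 - rho**2)
            - (rho**2 * x**2 - 2 * rho * x * y + rho**2 * y**2) / (2 * (1 - rho**2)))

def infonce_with_optimal_critic(rho, K=64, n_batches=200, seed=0):
    rng = np.random.default_rng(seed)
    bounds = []
    for _ in range(n_batches):
        x = rng.standard_normal(K)
        y = rho * x + np.sqrt(1 - rho**2) * rng.standard_normal(K)  # (x_i, y_i) ~ p(x, y)
        # scores[i, j] = f(x_i, y_j); the diagonal holds the positive pairs, and the
        # off-diagonal y_j come from other pairs, so they are independent of x_i
        scores = log_ratio(x[:, None], y[None, :], rho)
        m = scores.max(axis=1, keepdims=True)
        log_norm = m[:, 0] + np.log(np.exp(scores - m).sum(axis=1))  # stable log-sum-exp
        bounds.append(np.log(K) + np.mean(np.diag(scores) - log_norm))  # log K - InfoNCE loss
    return float(np.mean(bounds))

for rho in [0.5, 0.9999999]:
    true_mi = -0.5 * np.log(1 - rho**2)
    print(f"rho = {rho}: true MI = {true_mi:.3f} nats, "
          f"InfoNCE bound = {infonce_with_optimal_critic(rho):.3f} nats, "
          f"log K = {np.log(64):.3f}")
```

At \(\rho = 0.5\) the bound lands close to the true MI of about 0.144 nats. At \(\rho = 0.9999999\) the true MI is about 7.7 nats, but even the optimal critic cannot push the bound above \(\log 64 \approx 4.16\).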

5. Coding Demos

Demo 1: The NWJ Mutual Information Estimator in PyTorch

This demo trains a simple neural network to estimate the MI between two correlated Gaussians.

```python
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

class Critic(nn.Module):
    def __init__(self, hidden_dim=64):
        super().__init__()
        # Input is the concatenation of x and y (one scalar each)
        self.net = nn.Sequential(
            nn.Linear(2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x, y):
        # Concatenate along the feature dimension: (batch, 1) + (batch, 1) -> (batch, 2)
        xy = torch.cat([x, y], dim=-1)
        return self.net(xy)

def train_nwj(rho, epochs=500, batch_size=512):
    # Covariance matrix of a standard bivariate Gaussian with correlation rho
    cov = np.array([[1.0, rho], [rho, 1.0]])
    true_mi = -0.5 * np.log(1 - rho**2)

    critic = Critic()
    optimizer = optim.Adam(critic.parameters(), lr=1e-3)

    # Each "epoch" here is one gradient step on a freshly sampled batch
    for _ in range(epochs):
        # Sample from the joint p(x, y)
        xy_joint = np.random.multivariate_normal([0, 0], cov, batch_size)
        x_joint = torch.tensor(xy_joint[:, 0:1], dtype=torch.float32)
        y_joint = torch.tensor(xy_joint[:, 1:2], dtype=torch.float32)

        # Approximate samples from p(x)p(y) by shuffling y across the batch,
        # which breaks its pairing with x
        y_marginal = y_joint[torch.randperm(batch_size)]

        # Critic values on joint and on product-of-marginals pairs
        t_joint = critic(x_joint, y_joint)
        t_marginal = critic(x_joint, y_marginal)

        # NWJ loss to minimize (negative of the bound):
        # E_p(x,y)[T] is the mean of t_joint; E_p(x)p(y)[exp(T - 1)] is the mean of exp(t_marginal - 1)
        loss = -(torch.mean(t_joint) - torch.mean(torch.exp(t_marginal - 1.0)))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Report the bound evaluated on the final minibatch (a noisy single-batch estimate)
    estimated_mi = -loss.item()
    print(f"True MI: {true_mi:.4f} nats, NWJ Estimated MI: {estimated_mi:.4f} nats")

# Test with correlation 0.8
train_nwj(rho=0.8)
```
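Output from one run (the value changes from run to run; because it is the bound evaluated on the last 512-sample minibatch, it can land slightly above the true MI even though the bound itself cannot exceed it):

```text
True MI: 0.5108 nats, NWJ Estimated MI: 0.5416 nats
```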

Demo 2: InfoNCE Loss Implementation

A compact implementation of the InfoNCE loss for a batch of paired embeddings.

```python
import torch
import torch.nn.functional as F
import numpy as np

def infonce_loss(features_a, features_b, temperature=0.1):
    """
    Computes the InfoNCE loss given two views of a batch.
    features_a: (Batch, Dim)
    features_b: (Batch, Dim)
    Row i of features_a and row i of features_b form the positive pair;
    the other rows of features_b act as the K - 1 negatives (K = Batch).
    """
    batch_size = features_a.size(0)

    # Normalize features to the unit sphere so dot products are cosine similarities
    features_a = F.normalize(features_a, dim=-1)
    features_b = F.normalize(features_b, dim=-1)

    # Similarity matrix (Batch x Batch): entry (i, j) is cos(a_i, b_j) / temperature
    sim_matrix = torch.matmul(features_a, features_b.T) / temperature

    # The positive pairs are on the diagonal: labels are [0, 1, ..., Batch-1]
    labels = torch.arange(batch_size, device=features_a.device)

    # InfoNCE is standard cross-entropy on this similarity matrix
    loss = F.cross_entropy(sim_matrix, labels)

    # The MI lower bound is log(K) - loss, with K = batch size
    mi_bound = np.log(batch_size) - loss.item()

    return loss, mi_bound

# Simulate embeddings
B, D = 64, 128
# Positives are close: fb is a slightly perturbed copy of fa
fa = torch.randn(B, D)
fb = fa + torch.randn(B, D) * 0.1

loss, mi = infonce_loss(fa, fb)
print(f"Batch Size: {B}, Max possible bound: {np.log(B):.3f}")
print(f"InfoNCE Loss: {loss.item():.3f}")
print(f"MI Lower Bound: {mi:.3f} nats")
```
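Output from one run (the embeddings are random, so the values vary slightly):

```text
Batch Size: 64, Max possible bound: 4.159
InfoNCE Loss: 0.004
MI Lower Bound: 4.154 nats
```

The bound sits just below \(\log 64 \approx 4.159\), even though the true MI between these 128-dimensional views is far larger (each coordinate pair alone carries \(\frac{1}{2}\log(1 + 100) \approx 2.31\) nats). This is the \(\log K\) ceiling in action.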