
4.5 Tensor Programs and Scaling

Introduction

The Master Theorem of Tensor Programs provides a rigorous calculus for computing the infinite-width limit of a broad class of neural network architectures. A network's computation graph is expressed as a sequence of matrix-vector multiplications and coordinate-wise non-linearities, and the theorem then describes how the coordinates of every intermediate vector are distributed as the width grows.

Tensor Programs Framework

A Tensor Program is defined by a set of vectors \(V\) and a set of random matrices \(W\). Vectors are generated iteratively via matrix multiplications \(h = Wx\) and coordinate-wise non-linearities \(x = \phi(h_1, \dots, h_k)\).

Theorem 4.5.1 (The Master Theorem)

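Let \(h^{(1)}, \dots, h^{(L)} \in \mathbb{R}^n\) be vectors generated by a Tensor Program. As the dimension \(n \to \infty\), the coordinates of these vectors behave like i.i.d. samples from a multivariate Gaussian distribution: for every sufficiently regular test function \(\psi\) (for example, polynomially bounded), the coordinate average converges almost surely,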

\[ \frac{1}{n} \sum_{i=1}^n \psi(h_i^{(1)}, \dots, h_i^{(L)}) \xrightarrow{n \to \infty} \mathbb{E}_{(Z_1, \dots, Z_L) \sim \mathcal{N}(0, \Sigma)} [\psi(Z_1, \dots, Z_L)] \]

where \(\Sigma\) is a deterministic covariance matrix determined recursively by the program's structure.

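Proof sketch: The proof proceeds by induction on the steps of the Tensor Program. For the base case, the initial vectors are Gaussian by definition. For the inductive step, consider \(h = Wx\), where \(W\) has i.i.d. Gaussian entries. If \(W\) is independent of \(x\), then conditionally on \(x\) the coordinates of \(h\) are exactly i.i.d. centered Gaussians with variance proportional to \(\|x\|^2\), so no central limit theorem is needed. The crux of the proof is the dependence introduced when \(W\) is reused, because later inputs may themselves depend on \(W\). This is handled by conditioning on the earlier products: given those linear constraints, \(W\) is still Gaussian, and the new product splits into a Gaussian part plus a remainder that is accounted for in the limit. When \(W\) is applied to two inputs \(x_1\) and \(x_2\), the coordinates of \(Wx_1\) and \(Wx_2\) are jointly Gaussian in the limit, with covariance determined by the limiting inner product of \(x_1\) and \(x_2\). This is how the recursion for \(\Sigma\) records the geometry of the vectors before multiplication. Finally, a law of large numbers for suitably regular functions \(\psi\) ensures that empirical averages converge to their Gaussian expectations. \(\blacksquare\)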
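The statement can be checked numerically in its simplest case. The sketch below is an illustration rather than part of the theorem: it assumes a single shared matrix \(W\) with i.i.d. \(\mathcal{N}(0, 1/n)\) entries, two fixed inputs, and the test function \(\psi(a, b) = ab\). The function name `empirical_vs_gaussian` and all numeric choices are made up for this example.

```python
import numpy as np

def empirical_vs_gaussian(n, seed=0):
    """Compare a coordinate average with its Gaussian prediction at width n."""
    rng = np.random.default_rng(seed)
    # Two fixed, correlated input vectors with O(1) coordinates.
    x1 = rng.standard_normal(n)
    x2 = 0.5 * x1 + rng.standard_normal(n)
    # One shared random matrix with i.i.d. N(0, 1/n) entries (assumed scaling).
    W = rng.standard_normal((n, n)) / np.sqrt(n)
    h1, h2 = W @ x1, W @ x2
    # Left-hand side: coordinate average of psi(h1_i, h2_i) = h1_i * h2_i.
    empirical = float(np.mean(h1 * h2))
    # Right-hand side: E[Z1 * Z2] with Cov(Z1, Z2) = <x1, x2> / n.
    predicted = float(x1 @ x2 / n)
    return empirical, predicted

for n in [100, 1000, 2000]:
    emp, pred = empirical_vs_gaussian(n)
    print(f"n={n}: empirical={emp:.4f}, predicted={pred:.4f}")
```

The gap between the two columns should shrink as \(n\) grows, which is the behavior the theorem describes for this choice of \(\psi\).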

Maximal Update Parametrization (muP)

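Neither the standard parametrization (SP) nor the NTK parametrization has an infinite-width limit that learns features. Under SP, a learning rate that works at small width makes the network output blow up as the width grows, so the learning rate must be retuned (in practice, reduced) for each width. Under the NTK parametrization the limit is well-defined, but the hidden features stay at their initial values and training reduces to kernel regression.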

Theorem 4.5.2 (Feature Learning in muP)

The Maximal Update Parametrization (muP) defines the learning rates and initializations such that:
1. The network activations remain \(O(1)\) throughout training.
2. The weight updates \(\Delta W\) induce an \(O(1)\) change in the feature representations (unlike NTK).
3. The infinite-width limit is well-defined and exhibits non-trivial feature learning.

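Proof sketch: Write each hidden layer as \(h = \frac{1}{\sqrt{n}} W x\), with the entries of \(W\) initialized i.i.d. \(\mathcal{N}(0, 1)\). The effective weight \(W/\sqrt{n}\) then has entries of size \(1/\sqrt{n}\), and \(h\) has \(O(1)\) coordinates whenever \(x\) does. In this convention the learning rate for hidden weights scales as \(\eta \propto n\). For a single example, the gradient is the rank-one matrix \(\nabla_W \mathcal{L} = \frac{1}{\sqrt{n}}\, g\, x^\top\) with \(g = \partial \mathcal{L} / \partial h\), so one SGD step gives \(\Delta W = -\frac{\eta}{\sqrt{n}}\, g\, x^\top\). Applying the updated weights to an input with activations \(x'\) changes the features by \(\Delta h = \frac{1}{\sqrt{n}} \Delta W x' = -\eta\, g\, \frac{x^\top x'}{n}\). Because \(\Delta W\) is aligned with \(x\), the factor \(x^\top x'/n\) is generically \(\Theta(1)\) instead of vanishing as it would for an independent random matrix. Under muP the output layer is scaled so that \(g\) has coordinates of size \(\Theta(1/n)\), and with \(\eta \propto n\) the coordinates of \(\Delta h\) are \(\Theta(1)\). Thus the features evolve by an \(O(1)\) amount in a single step without diverging. By contrast, under the NTK parametrization \(g\) has coordinates of size \(\Theta(1/\sqrt{n})\) and \(\eta = O(1)\), so \(\Delta h \to 0\) and the features freeze. This scaling argument shows that muP admits feature learning in the infinite-width limit; it does not by itself establish that muP is the only parametrization that does. \(\blacksquare\)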
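The scaling argument can be sketched in code with a single rank-one update. This is a toy under stated assumptions, not a training run: hidden layers are written as \(h = Wx/\sqrt{n}\), the backpropagated signal \(g\) is a random vector whose coordinate size is set by hand (\(1/n\) for muP with learning rate \(n\), \(1/\sqrt{n}\) for NTK with learning rate 1), and the update is evaluated on the same input that produced it. The function name `feature_shift_rms` is hypothetical.

```python
import numpy as np

def feature_shift_rms(n, scheme, seed=0):
    """RMS coordinate size of the feature change after one rank-one update."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(n)   # incoming activations with O(1) coordinates
    z = rng.standard_normal(n)   # direction of the backpropagated signal
    if scheme == "mup":
        g, lr = z / n, float(n)        # assumed muP scales: g ~ 1/n, lr ~ n
    elif scheme == "ntk":
        g, lr = z / np.sqrt(n), 1.0    # assumed NTK scales: g ~ 1/sqrt(n), lr ~ 1
    else:
        raise ValueError(f"unknown scheme: {scheme!r}")
    # Gradient of the loss w.r.t. W for h = W x / sqrt(n) is g x^T / sqrt(n).
    delta_W = -lr * np.outer(g, x) / np.sqrt(n)
    # Effect of the update on the features for the same input.
    delta_h = delta_W @ x / np.sqrt(n)
    return float(np.sqrt(np.mean(delta_h ** 2)))

for n in [100, 400, 1600]:
    print(f"n={n}: muP shift={feature_shift_rms(n, 'mup'):.4f}, "
          f"NTK shift={feature_shift_rms(n, 'ntk'):.4f}")
```

Under these assumptions the muP column should stay of order one while the NTK column shrinks roughly like \(1/\sqrt{n}\). The key ingredient is that the update is aligned with the input, which a random stand-in for \(\Delta W\) would not capture.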

Hyperparameter Transfer

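Because the training dynamics under muP converge to a well-defined feature-learning limit as the width \(n\) grows, hyperparameters such as the learning rate that are tuned on a narrow model remain approximately optimal on wider ones. The transfer is approximate at finite width and improves as the width increases. It is best established for optimization hyperparameters such as the learning rate; for regularization such as weight decay it should be verified empirically.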

Worked Examples

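Example 1: MLP in Tensor Programs

Consider an MLP with a fixed-dimensional input \(x\): \(h^0 = W_{in} x\), \(x^1 = \phi(h^0)\), \(h^1 = W_1 x^1\). Assume \(W_{in}\) has i.i.d. \(\mathcal{N}(0, 1)\) entries and \(W_1\) has i.i.d. \(\mathcal{N}(0, 1/n)\) entries. Each coordinate of \(h^0\) is then distributed as \(Z \sim \mathcal{N}(0, \|x\|^2)\). By the Master Theorem, \(\frac{1}{n}\|x^1\|^2 \to \mathbb{E}[\phi(Z)^2]\), so the coordinates of \(h^1\) are asymptotically \(\mathcal{N}(0, \mathbb{E}[\phi(Z)^2])\) as \(n \to \infty\).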
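A quick Monte Carlo check of Example 1, as a sketch under explicit assumptions: \(\phi\) is ReLU, \(W_{in}\) has \(\mathcal{N}(0, 1)\) entries, \(W_1\) has \(\mathcal{N}(0, 1/n)\) entries, and the input dimension is fixed at 8. For ReLU, \(\mathbb{E}[\phi(Z)^2] = \|x\|^2 / 2\). The function name `mlp_second_layer_variance` is made up for this example.

```python
import numpy as np

def mlp_second_layer_variance(n, d=8, seed=0):
    """Empirical variance of h^1 coordinates vs. the predicted E[relu(Z)^2]."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(d)                      # fixed low-dimensional input
    W_in = rng.standard_normal((n, d))              # entries N(0, 1)
    W_1 = rng.standard_normal((n, n)) / np.sqrt(n)  # entries N(0, 1/n)
    h0 = W_in @ x                 # coordinates distributed as N(0, ||x||^2)
    x1 = np.maximum(h0, 0.0)      # ReLU non-linearity
    h1 = W_1 @ x1
    empirical = float(np.mean(h1 ** 2))
    predicted = float(x @ x / 2.0)  # E[relu(Z)^2] for Z ~ N(0, ||x||^2)
    return empirical, predicted

for n in [200, 2000]:
    emp, pred = mlp_second_layer_variance(n)
    print(f"n={n}: empirical var={emp:.4f}, predicted var={pred:.4f}")
```

The two columns should agree more closely at the larger width.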

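Example 2: SP vs muP Learning Rates

In SP, a single learning rate \(\eta \propto 1\) is applied regardless of width, so training a 10x wider model usually requires retuning it, typically by heuristically shrinking it. In muP, with hidden layers written as \(h = \frac{1}{\sqrt{n}} W x\), the width dependence \(\eta_{hidden} \propto n\) is built into the parametrization. The base learning rate tuned at small width is therefore reused unchanged on the wider model.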
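One way to read the rule \(\eta_{hidden} \propto n\) is as a width multiplier applied to a base learning rate that was tuned once on a narrow model. The helper below is a hypothetical illustration of that bookkeeping, not an API of any library. The base learning rate and base width are illustrative values, and the linear rule applies to the hidden-layer convention \(h = \frac{1}{\sqrt{n}} W x\) used in this section.

```python
def hidden_lr(base_lr, base_width, width):
    """Hypothetical helper: scale a hidden-layer learning rate linearly with width."""
    if base_width <= 0 or width <= 0:
        raise ValueError("widths must be positive")
    return base_lr * (width / base_width)

# Illustrative values: the base learning rate is tuned once at width 256.
base_lr, base_width = 0.1, 256
for width in [256, 1024, 4096]:
    lr = hidden_lr(base_lr, base_width, width)
    print(f"width={width}: hidden-layer lr={lr:.3f}")
```

The tuned quantity is `base_lr`; only the multiplier changes with width, which is the sense in which the same learning rate is reused.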

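Example 3: Attention under muP

Standard Transformers scale the query-key dot product by \(1/\sqrt{d}\), which keeps the attention logits \(O(1)\) when queries and keys are independent, as they are at initialization. During training, queries and keys become correlated, so \(q^\top k\) grows like \(d\) instead of \(\sqrt{d}\). muP therefore scales the attention logits by \(1/d\), which keeps them \(O(1)\) throughout training and lets the gradients flow correctly.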
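The difference between the two scalings can be seen with random vectors. The sketch below is a toy illustration: the independent case stands in for queries and keys at initialization, and the fully correlated case (\(k = q\)) stands in for the extreme of correlation that can develop during training. The function name `attention_logit` is hypothetical.

```python
import numpy as np

def attention_logit(d, correlated, seed=0):
    """Return one query-key logit under 1/sqrt(d) scaling and under 1/d scaling."""
    rng = np.random.default_rng(seed)
    q = rng.standard_normal(d)
    # Correlated case: key equals query (an extreme stand-in for trained weights).
    k = q.copy() if correlated else rng.standard_normal(d)
    dot = float(q @ k)
    return dot / np.sqrt(d), dot / d

for d in [64, 256, 1024]:
    ind_sqrt, ind_lin = attention_logit(d, correlated=False)
    cor_sqrt, cor_lin = attention_logit(d, correlated=True)
    print(f"d={d}: independent 1/sqrt(d)={ind_sqrt:.3f}, 1/d={ind_lin:.3f} | "
          f"correlated 1/sqrt(d)={cor_sqrt:.3f}, 1/d={cor_lin:.3f}")
```

In the correlated case the \(1/\sqrt{d}\) logit grows with \(d\) while the \(1/d\) logit stays near 1. In the independent case the \(1/d\) logit shrinks toward 0, which is the trade-off the muP scaling accepts at initialization.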

Coding Demos

Demo 1: Hyperparameter Transfer Simulation (muP)

Python
import numpy as np

# Simulate feature update magnitudes across widths
widths = [100, 1000, 10000]

for n in widths:
    x = np.random.randn(n) / np.sqrt(n) # Input

    # Standard Param (SP)
    grad_sp = np.random.randn(n, n) / np.sqrt(n) # simplified gradient
    lr_sp = 1.0 # SP learning rate is O(1)
    delta_h_sp = (lr_sp * grad_sp) @ x

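    # muP (toy): random stand-in gradient with entries of size 1/n and LR of size n,
    # so delta_W entries are O(1); the forward pass below adds the 1/sqrt(n) multiplier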
    grad_mup = np.random.randn(n, n) / n       # simplified gradient (O(1/n))
    lr_mup = n                                  # muP learning rate is O(n)
    delta_h_mup = (lr_mup * grad_mup) @ x / np.sqrt(n)

    print(f"Width {n}: SP Update = {np.linalg.norm(delta_h_sp):.4f}, "
          f"muP Update = {np.linalg.norm(delta_h_mup):.4f}")
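# With these random stand-in gradients, both columns stay close to 1 at every width.
# A random gradient ignores the alignment between real gradients and the input,
# so this toy does not show why SP needs its learning rate retuned as width grows.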

Text Only
Width 100: SP Update = 0.8587, muP Update = 0.8429
Width 1000: SP Update = 1.0172, muP Update = 0.9963
Width 10000: SP Update = 0.9963, muP Update = 0.9837

Demo 2: Feature Learning Test

Python
import numpy as np

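def test_feature_learning(width, param_type='mup'):
    """Toy comparison of pre-activation shifts; returns ||h_new - h_init||.

    delta_W is a random stand-in for a weight update, not a real gradient.
    """
    n = width
    x = np.random.randn(n)                  # input with O(1) coordinates
    W = np.random.randn(n, n) / np.sqrt(n)  # weights with entries of size 1/sqrt(n)
    h_init = (W @ x) / np.sqrt(n)           # same starting pre-activation for both schemes

    if param_type == 'ntk':
        # O(1) learning rate: the stand-in update has the same scale as W
        delta_W = np.random.randn(n, n) / np.sqrt(n) * 1.0
        h_new = ((W + delta_W) @ x) / np.sqrt(n)
    elif param_type == 'mup':
        # O(n) learning rate: the stand-in update is n times larger
        delta_W = np.random.randn(n, n) / np.sqrt(n) * n
        # The extra 1/sqrt(n) rescales the W term as well as the update,
        # so the reported shift is dominated by (delta_W @ x) / n
        h_new = ((W + delta_W) @ x) / np.sqrt(n) / np.sqrt(n)
    else:
        raise ValueError(f"unknown param_type: {param_type!r}")

    return np.linalg.norm(h_new - h_init)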

print(f"NTK Feature Shift (n=10k): {test_feature_learning(10000, 'ntk'):.4f}")
print(f"muP Feature Shift (n=10k): {test_feature_learning(10000, 'mup'):.4f}")

Text Only
NTK Feature Shift (n=10k): 0.9748
muP Feature Shift (n=10k): 98.9391