
Superposition Experiment

This article is a continuation of the Superposition Briefly post.

In the previous article on superposition I briefly described the phenomenon. To get a more practical lens on how it arises in practice, I'll be replicating the experiment from Anthropic's paper.

Before we even try to visualize and see superposition in practice, we need to identify what kind of data we should be working with. We have two premises about the data that lead to interpretable results: features are sparse (each one is zero most of the time), and features differ in how important they are to the model.

We set up a small model with $n = 20$ and $m = 5$, where $n$ is the number of features and $m$ is the number of dimensions our model has. We also need to vary the sparsity level and assign a different importance to each feature.

As for the synthetic data, the input vectors $x$ simulate the mentioned premises. Every $x_i$ (which is a "feature") has an associated sparsity $S$ and importance $I_i$. Every $x_i$ equals $0$ with probability $S$ and is uniformly distributed on $[0, 1]$ otherwise. As for the importance, the paper uses geometric decay: $I_i = 0.7^i$. $0.7$ is an arbitrarily chosen base and isn't a magic number. Looking at $I$:

$$I = \begin{bmatrix} 1.0 & 0.7 & 0.49 & 0.34 & 0.24 & \dots & 0.7^{19} \end{bmatrix}$$
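To make this concrete, here's a minimal sketch of how the synthetic data could be generated. I'm assuming PyTorch; the names (`sample_batch`, `n_features`, etc.), the batch size, and the sparsity value are my own choices, not something prescribed by the paper.

```python
import torch

n_features = 20   # n
sparsity = 0.9    # S: probability that a given feature is exactly 0 (example value)

# Geometric importance decay: I_i = 0.7^i
importance = 0.7 ** torch.arange(n_features, dtype=torch.float32)

def sample_batch(batch_size, n_features, sparsity):
    # Each feature is Uniform[0, 1]...
    x = torch.rand(batch_size, n_features)
    # ...but is zeroed out independently with probability S.
    mask = torch.rand(batch_size, n_features) < (1.0 - sparsity)
    return x * mask

x = sample_batch(1024, n_features, sparsity)
```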

Importance affects the loss — errors on more significant features are penalized more heavily, so the model prioritizes representing them. The loss is:

$$\mathcal{L} = \sum_{i=0}^{n-1} I_i (x_i - \hat{x}_i)^2$$
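A sketch of this loss, assuming the same PyTorch setup as above; averaging over the batch dimension is my choice, not something the formula specifies.

```python
def weighted_mse(x, x_hat, importance):
    # L = sum_i I_i * (x_i - x_hat_i)^2, averaged over the batch
    return (importance * (x - x_hat) ** 2).sum(dim=-1).mean()
```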

Now that we have identified the loss function, what model are we actually training to minimize it?

The model tries to reconstruct the input $x$ with its $n$ features after squeezing it through an $m$-dimensional space. The model looks like this:

$$\begin{aligned} h &= W x \\ \hat{x} &= \operatorname{ReLU}(W^T h + b) \\ &= \operatorname{ReLU}(W^T W x + b) \end{aligned}$$

The paper's hypothesis is that every feature in the $n$-dimensional space can be represented in the lower $m$-dimensional one. We use a linear map $W$, where $W \in \mathbb{R}^{m \times n}$ is the weight matrix. Each column $W_i$ represents the direction of the feature $x_i$.

We use the transpose of the matrix, $W^T$, to recover the original vector.

We also add a bias to the recovered result. The reason for doing so is to allow the model to nudge features toward their expected values.

Our model also uses an activation function, which seems to be important for superposition. When I first read the paper, it wasn't intuitive to me why adding non-linearity matters so much for the model to superpose features.
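Putting the pieces together, a minimal PyTorch sketch of the model could look like this. The class and parameter names are mine, and the initialization scheme is an arbitrary choice rather than the paper's exact setup.

```python
import torch
import torch.nn as nn

class ToyModel(nn.Module):
    """x_hat = ReLU(W^T W x + b), with W in R^{m x n}."""

    def __init__(self, n_features=20, n_dims=5):
        super().__init__()
        self.W = nn.Parameter(torch.empty(n_dims, n_features))
        nn.init.xavier_normal_(self.W)
        self.b = nn.Parameter(torch.zeros(n_features))

    def forward(self, x):
        h = x @ self.W.T                           # h = W x, shape (batch, m)
        return torch.relu(h @ self.W + self.b)     # ReLU(W^T h + b), shape (batch, n)
```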

Visualization

$W^T W$ across different sparsities

We see that in the densest case ($S = 0$), only the diagonal entries for the 5 most important features are highlighted. As sparsity increases, the model starts to represent more features, but more noise also appears in the off-diagonal entries.
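For reference, a plot like this can be produced with something along these lines. Here `models` is a hypothetical dict mapping a sparsity level to a trained `ToyModel`, and the sparsity values are just examples.

```python
import matplotlib.pyplot as plt

sparsities = [0.0, 0.7, 0.9, 0.99]   # example values

fig, axes = plt.subplots(1, len(sparsities), figsize=(3 * len(sparsities), 3))
for ax, s in zip(axes, sparsities):
    W = models[s].W.detach()          # shape (m, n); `models` is assumed to exist
    gram = (W.T @ W).numpy()          # W^T W, shape (n, n)
    ax.imshow(gram, cmap="RdBu", vmin=-1, vmax=1)
    ax.set_title(f"S = {s}")
plt.show()
```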

Analytical insight

Besides showing the actual loss $\mathcal{L}$ that would be computed during training, the paper analytically explains why superposition occurs using this equation:

$$\mathbb{E}_x[\mathcal{L}] \sim \underbrace{\sum_i I_i \left(1 - \|W_i\|^2\right)^2}_{\text{feature benefit}} + \underbrace{\sum_{i \neq j} I_j \left(W_j \cdot W_i\right)^2}_{\text{interference}}$$

Feature benefit is the term that rewards the model for representing feature $i$ at all: it is smallest when $\|W_i\|^2 = 1$, i.e. when the feature gets a full-norm direction of its own.

Interference is the price of packing features into fewer dimensions: whenever two feature directions $W_i$ and $W_j$ are not orthogonal, their dot product leaks one feature's activation into the other's reconstruction.
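To get a feel for the two terms, here's a sketch that computes them directly from a weight matrix. The function name and the exact reduction are mine; this mirrors the formula above rather than any code from the paper.

```python
import torch

def loss_decomposition(W, importance):
    # W has shape (m, n); column W_i is the direction of feature i.
    gram = W.T @ W                          # entries are W_i . W_j
    norms_sq = gram.diagonal()              # ||W_i||^2
    feature_benefit = (importance * (1 - norms_sq) ** 2).sum()
    off_diag = gram - torch.diag(norms_sq)  # zero out the i = j entries
    interference = (importance * (off_diag ** 2).sum(dim=0)).sum()
    return feature_benefit, interference
```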

Full derivation: from MSE to feature benefit + interference

We start the derivation from our original MSE loss:

$$\mathcal{L} = \sum_{i=0}^{n-1} I_i (x_i - \hat{x}_i)^2$$

Now we substitute $\hat{x}_i$ in terms of $x$ (ignoring the ReLU and bias here to keep the algebra simple):

$$\hat{x}_i = (W^T W x)_i$$

Knowing that $(W^T W)_{ij} = W_i \cdot W_j$, we replace the matrix multiplication with an explicit sum:

$$\begin{aligned} \hat{x}_i &= (W^T W x)_i \\ &= \sum_{j} (W^T W)_{ij} x_j \\ &= \sum_{j} (W_i \cdot W_j) x_j \\ &= (W_i \cdot W_i) x_i + \sum_{j \neq i} (W_i \cdot W_j) x_j \end{aligned}$$

Since $W_i \cdot W_i = \lVert W_i \rVert^2$, we substitute that for the $i = j$ case:

$$\hat{x}_i = \lVert W_i \rVert^2 x_i + \sum_{j \neq i} (W_i \cdot W_j) x_j$$

Plugging this back into the error term of the original loss:

$$\begin{aligned} x_i - \hat{x}_i &= x_i - \lVert W_i \rVert^2 x_i - \sum_{j \neq i} (W_i \cdot W_j) x_j \\ &= x_i (1 - \lVert W_i \rVert^2) - \sum_{j \neq i} (W_i \cdot W_j) x_j \end{aligned}$$
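As a quick sanity check of the algebra (again in the linear case, without the ReLU and bias), the expanded error matches the matrix form numerically; the dimensions and seed below are arbitrary.

```python
import torch

torch.manual_seed(0)
m, n = 5, 20
W = torch.randn(m, n)
x = torch.rand(n)

gram = W.T @ W                      # (W^T W)_{ij} = W_i . W_j
x_hat = gram @ x                    # x_hat = W^T W x (linear case)

norms_sq = gram.diagonal()
off_diag = gram - torch.diag(norms_sq)
# x_i (1 - ||W_i||^2) - sum_{j != i} (W_i . W_j) x_j
error_expanded = x * (1 - norms_sq) - off_diag @ x

assert torch.allclose(x - x_hat, error_expanded, atol=1e-5)
```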