Forward Self Models

Forward Self-Models Learn an Empirical Approximation of Neural Network Computation

Jasper Gilley

We introduce forward self-models: small networks trained to predict a neural network's later-layer activations from its earlier-layer activations, thereby learning an empirical approximation of the computational function that the intervening layers implement. We demonstrate this by training forward self-models with 1-3% of the main model's parameter count on main models ranging from 30M to 1B parameters, achieving up to 97% cosine similarity with the target activations and recovering up to 94% of a layer's KL-divergence contribution to the output distribution. Forward self-model prediction errors are interpretably structured and track the gap in computational complexity between the main model and the forward model. We argue that forward self-models provide a primitive both for mechanistic interpretability and for future model architectures that require an explicit model of their own computational dynamics. Code is available at https://github.com/jagilley/forward-self-models.

1. Introduction

A transformer's layers transform representations sequentially through the residual stream. Each layer applies attention and feedforward computation, producing output activations that become the next layer's input. The computational function each layer implements is central to understanding how the model works, but it is difficult to characterize directly. Existing approaches study this function indirectly: probing identifies what information is linearly accessible at each layer, and ablation measures each layer's causal contribution to the output. These approaches characterize what information exists and where it matters, but say less about the transformation itself.

We introduce a technique that directly approximates the computational function of a model's layers. A forward self-model is a small auxiliary network trained to predict a main model's later-layer activations from its earlier-layer activations. It learns, through observation alone, how the intervening layers transform their inputs. Because the forward model is deliberately small (1-3% of the main model's parameters), its approximation is imperfect: what it captures is specifically the compressible component of the layer's computation. What it misses, the prediction residual, measures 'computational novelty': the aspects of the computation that are genuinely hard to compress.
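To make the setup concrete, here is a minimal sketch of the training loop in plain Python. Everything in it is an illustrative assumption rather than the released implementation: the "main layer" is a toy function (`main_layer`), the forward model is a residual linear map with weights `A`, and the width, learning rate, and sample count are arbitrary. It only shows the shape of the idea: collect activation pairs from a frozen model, fit a much simpler predictor by MSE, and look at what remains.

```python
import math
import random

random.seed(0)
DIM = 4  # toy residual-stream width (assumption, not from the post)

# Toy stand-in for one frozen main-model layer: h_out = h_in + tanh(W h_in).
W_LAYER = [[random.gauss(0, 0.3) for _ in range(DIM)] for _ in range(DIM)]

def matvec(m, v):
    return [sum(w * x for w, x in zip(row, v)) for row in m]

def main_layer(h_in):
    return [h + math.tanh(z) for h, z in zip(h_in, matvec(W_LAYER, h_in))]

def mse(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)

# Observe (earlier-layer, later-layer) activation pairs from the frozen layer.
inputs = [[random.gauss(0, 1) for _ in range(DIM)] for _ in range(200)]
pairs = [(h, main_layer(h)) for h in inputs]

# Hypothetical forward model: a residual linear map h_in + A h_in.
# It is deliberately simpler than the layer it predicts.
A = [[0.0] * DIM for _ in range(DIM)]

def forward_model(h_in):
    return [h + z for h, z in zip(h_in, matvec(A, h_in))]

def mean_loss():
    return sum(mse(forward_model(h), t) for h, t in pairs) / len(pairs)

def train(epochs=50, lr=0.05):
    # Plain SGD on the MSE between predicted and observed activations.
    for _ in range(epochs):
        for h_in, target in pairs:
            pred = forward_model(h_in)
            for i in range(DIM):
                grad_out = 2.0 * (pred[i] - target[i]) / DIM
                for j in range(DIM):
                    A[i][j] -= lr * grad_out * h_in[j]

loss_before = mean_loss()
train()
loss_after = mean_loss()
print(f"MSE before: {loss_before:.4f}, after: {loss_after:.4f}")
```

In this toy, the loss that remains after training is the part of the `tanh` nonlinearity that a linear map cannot express, which is a small-scale analogue of the prediction residual described above.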

We train forward self-models on language modeling tasks at various model sizes. On our own 30M-parameter GPT, a forward model predicting a single transformer layer achieves 0.97 cosine similarity with the target layer's activations. Attention pattern similarity is near-perfect (>0.98 cosine for 3 of 4 heads), while weight similarity is zero, owing to the gauge symmetry of the attention mechanism under rotations of the projection matrices. The forward model's prediction errors track computational complexity (distributed attention, long-range context matching) rather than prediction difficulty. In causal substitution, replacing the layer with the forward model's prediction recovers 94% of the layer's KL contribution, with uniform degradation across several behavioral categories.
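The two headline metrics can be sketched as follows. Cosine similarity is standard. The KL-recovery formula is an assumption about how "recovers 94% of the layer's KL contribution" might be computed: it compares the output distribution with the forward model substituted in against the distribution with the layer ablated, both measured against the intact model. The function names are hypothetical and the post's exact definition may differ.

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two activation vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def kl_divergence(p, q):
    """KL(p || q) for discrete distributions; q must be nonzero where p is."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def kl_recovery(p_full, p_ablated, p_substituted):
    """Assumed definition: fraction of the ablation-induced KL that is
    removed when the forward model's prediction replaces the layer.
    1.0 means the substitution matches the intact model; 0.0 means it is
    no better than ablating the layer."""
    kl_ablated = kl_divergence(p_full, p_ablated)
    kl_substituted = kl_divergence(p_full, p_substituted)
    return 1.0 - kl_substituted / kl_ablated

# Toy next-token distributions over a 3-token vocabulary (made-up numbers).
p_full = [0.7, 0.2, 0.1]         # intact model
p_ablated = [0.4, 0.4, 0.2]      # layer removed
p_substituted = [0.68, 0.21, 0.11]  # layer replaced by forward model
print(kl_recovery(p_full, p_ablated, p_substituted))
```

Under this definition, a substitution that exactly reproduces the intact model's distribution scores 1.0 and one that is no better than ablation scores 0.0.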

At Llama 3.2 1B scale, a 26.2M-parameter forward model (2.1% of Llama) achieves 0.94 cosine similarity with a layer's computation. Per-head decomposition reveals that the prediction error concentrates in the MLP rather than the attention heads, recovering a meaningful decomposition of which computations in a 1.24B-parameter model are predictable from the preceding layer state and which are not.

Our main contributions are:

The forward self-model technique for learning empirical approximations of neural network layer computation.

Causal fidelity: forward models can substitute for the layers they predict with minimal behavioral degradation, recovering 74-94% of the replaced layers' KL contribution at both the 29M- and 1.24B-parameter scales.

Interpretable residual structure: the prediction residual tracks computational complexity rather than prediction difficulty.

A meaningful decomposition of Llama 3.2 1B layer computation, recovered by a 2.1%-parameter forward model through MSE optimization on frozen activations alone, with the prediction error mapping onto the architectural boundary between attention and the MLP.

A dissociation between representation and computation in the forward self-model objective: conditioning on a model's intermediate representations allows the forward model to compress the computational function that the intervening layers implement, rather than the joint representation-computation problem that conventional model compression addresses. This dissociation explains both the technique's parameter efficiency and the interpretability of its prediction errors, and suggests forward self-models as a primitive for...
