Autoregressive next token prediction and KV Cache in transformers


Autoregressive next token prediction & KV Cache in transformers | by Frederik vom Lehn | Advanced Deep Learning | May, 2026 | Medium


Advanced Deep Learning

Deep learning is a subset of machine learning focused on training artificial neural networks to automatically learn and extract hierarchical representations from data.

Autoregressive next token prediction & KV Cache in transformers

Frederik vom Lehn

7 min read · 1 hour ago


Understand the optimization technique in LLMs to speed up token generation


The general overview (Image by author).

The Big Picture

Before we dive into attention heads, KV caches, and the mechanics of generation, it helps to zoom out and see what an autoregressive language model actually is at a glance.

A prompt enters as plain text: “How are you?”. A tokenizer chops it into vocabulary IDs (here 3, 7, 1, 9), prefixed with a BOS ("beginning of sequence") token. Each ID is just an integer pointing into a lookup table: a learned matrix of shape (vocab_size, c) where every row is the embedding vector for one token in the vocabulary. Selecting the rows for our 5 input IDs produces X, a (5, 4) matrix: five tokens, each living in a 4-dimensional embedding space. This is where text leaves the world of symbols and enters the world of vectors. We use toy dimensions for our examples here.

From here, X flows through a stack of decoder blocks. Each block has the same architecture, multi-head self-attention followed by an MLP, and each block transforms its input into a refined (5, 4) representation of the same shape. The trick that makes deep transformers trainable is the residual connection wrapped around every block: instead of replacing the input, each block adds to it (X₁ = X + block_output). Information flows along a continuous "residual stream" that each layer edits rather than overwrites. Stack three of these and you get X₃, the final hidden state.

The last step inverts the first. The unembedding matrix (often the lookup table transposed, since input and output vocabularies are the same) projects each row of X₃ back into vocabulary space, producing a (5, 12) logits matrix: a score for every vocabulary token at every position. For next-token generation, only the last row matters. Under greedy decoding, its argmax is the token the model says next. Here, that's token ID 5.

That’s the whole forward pass at altitude.
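
The overview above can be traced end to end in a few lines of NumPy. This is a minimal sketch under stated assumptions, not the author's code: the weights are random stand-ins for learned parameters, `toy_block` is a hypothetical placeholder for attention plus MLP, the BOS ID of 0 is invented for illustration, and the unembedding is assumed to be tied to the lookup table. Only the shapes and the data flow follow the text.

```python
import numpy as np

rng = np.random.default_rng(0)

vocab_size, c, n_blocks = 12, 4, 3  # toy sizes from the text

# Lookup table: one row per vocabulary token (random stand-in for learned weights).
embedding = rng.normal(size=(vocab_size, c))

# Assumed BOS id 0 (hypothetical), followed by the four prompt ids from the example.
token_ids = np.array([0, 3, 7, 1, 9])


def toy_block(x, w):
    """Hypothetical stand-in for attention + MLP: any map that keeps the (T, c) shape."""
    return np.tanh(x @ w)


block_weights = [rng.normal(size=(c, c)) for _ in range(n_blocks)]

# Embedding lookup: selecting rows of the table gives X with shape (5, 4).
x = embedding[token_ids]

# Residual stream: each block adds its output to the input instead of replacing it.
for w in block_weights:
    x = x + toy_block(x, w)

# Unembedding with tied weights (the lookup table transposed) gives (5, 12) logits.
logits = x @ embedding.T

# Greedy next-token choice: argmax over the last row only.
next_token_id = int(np.argmax(logits[-1]))

print(x.shape, logits.shape, next_token_id)
```

With random weights the predicted ID is meaningless; the point is that the hidden state keeps its (5, 4) shape through every block and that only the last row of the logits is used to pick the next token.
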
The rest of this article zooms in on what happens inside one of those decoder blocks and on the optimization, KV caching, that makes generating long sequences feasible at all.

Let's zoom in on what happens inside a single decoder layer during the first forward pass.


The Prefill Forward Pass (Image by author)

The Prefill Forward Pass

Before a language model can generate a single new token, it has to process the prompt. This step (prefill) runs the entire input sequence through the network in one parallel forward pass. Its job is twofold: produce the first predicted token, and populate the KV cache so that subsequent decode steps stay cheap.

Let’s walk through what happens to a 5-token prompt in a tiny model with hidden dimension c = 4, 2 attention heads, and a vocabulary of 12 tokens.

From tokens to Q, K, V

The input X arrives as a (5, 4) matrix: 5 tokens, each represented by a 4-dimensional embedding pulled from the lookup table. Three learned projection matrices Wq, Wk, Wv, each of shape (4, 4), transform X into the query, key, and value matrices Q, K, V, all of shape (5, 4).

Because we have 2 heads, each (5, 4) matrix is split column-wise into two (5, 2) slices, one slice per head. Each head will compute attention independently in its own 2-dimensional subspace.

Attention within a head

Inside a single head, attention is a weighted lookup. The head’s Q slice (5, 2) is multiplied by the transpose of its K slice to produce a (5, 5) matrix of attention scores: every token's query dotted with every token's key. The scores are scaled, a causal mask is applied (this is an autoregressive model, so token t must not see tokens > t), and a softmax turns each row of the matrix into a probability distribution over "which past tokens should I pull information from."

These weights then multiply the head’s V slice (5, 2), yielding the head's output of shape (5, 2): each token now holds a context-aware mix of value vectors from its allowed positions.

Concatenation and projection

The two heads’ outputs are concatenated back into a (5, 4) matrix, then passed through an output projection (4, 4).
The result, X', is again (5, 4), the same shape as the input, but every row now reflects information gathered from its own and earlier positions in the sequence.

The MLP

Each token’s vector is then sent independently through a two-layer MLP. W_up of shape (4, 8) expands each row to 8 dimensions, GeLU adds non-linearity, and W_down of shape...
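
The attention part of the prefill pass (projection to Q, K, V, the per-head split, causal attention, concatenation, and the output projection) can be sketched in NumPy as follows. This is an illustrative sketch rather than the author's implementation: all weights are random stand-ins, the names `Wo`, `split_heads`, and `prefill_attention` are assumptions introduced here, the scaling factor is assumed to be the usual 1/sqrt(d_head), and the cache is simply the full K and V matrices. The sketch stops at X' and leaves out the residual connection and the MLP.

```python
import numpy as np

rng = np.random.default_rng(0)

T, c, n_heads = 5, 4, 2  # toy sizes from the text
d_head = c // n_heads    # 2 dimensions per head

X = rng.normal(size=(T, c))  # stand-in for the embedded prompt
# Random stand-ins for the learned projections; Wo is an assumed name
# for the (4, 4) output projection.
Wq, Wk, Wv, Wo = (rng.normal(size=(c, c)) for _ in range(4))


def softmax(z):
    """Row-wise softmax, shifted by the row max for numerical stability."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def split_heads(m):
    """(T, c) -> (n_heads, T, d_head): column-wise slices, one per head."""
    return m.reshape(m.shape[0], n_heads, d_head).transpose(1, 0, 2)


def prefill_attention(x):
    t = x.shape[0]
    Q, K, V = x @ Wq, x @ Wk, x @ Wv  # each (T, c)
    Qh, Kh, Vh = split_heads(Q), split_heads(K), split_heads(V)

    # Every query dotted with every key, per head: (n_heads, T, T).
    scores = Qh @ Kh.transpose(0, 2, 1) / np.sqrt(d_head)

    # Causal mask: position t must not attend to positions > t.
    future = np.triu(np.ones((t, t), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)

    weights = softmax(scores)  # each row is a distribution over allowed positions
    heads = weights @ Vh       # (n_heads, T, d_head)

    # Concatenate the heads back to (T, c), then apply the output projection.
    concat = heads.transpose(1, 0, 2).reshape(t, c)
    x_prime = concat @ Wo

    # K and V are what prefill leaves behind for later decode steps.
    return x_prime, weights, (K, V)


X_prime, weights, kv_cache = prefill_attention(X)
print(X_prime.shape, weights.shape, kv_cache[0].shape, kv_cache[1].shape)
```

Because of the causal mask, the first four rows of X' do not change if the last prompt token changes. That is why the K and V rows computed for earlier positions can be stored and reused instead of recomputed.
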
