petite-vllm Part 2: KV Cache & Paged Attention


By Kristen McIntosh


KV Cache

In part 1 we implemented a simple autoregressive loop and LLM interface. There was no KV caching, which means that at each token generation step we recomputed the $K$ and $V$ projections for every previous token.

```python
for tok_id in range(max_tokens):
    # No KV cache: every step re-runs the model over the full sequence so far.
    positions = torch.arange(toks.shape[1])  # get positions based on current_seq_len
    all_logits = self.model.forward(toks, positions)
    last_logits = all_logits[:, -1, :]  # only the final position predicts the next token
    new_tkn = sample(last_logits, temperature, top_k)
    toks = torch.cat([toks, new_tkn.unsqueeze(0)], dim=-1)  # append the sampled token
```

To get some intuition for why we need a KV cache, let's take a quick detour into the details of attention and establish some context.

The attention formula:

$$\text{Attention}(Q, K, V) = \text{Softmax}\left(\frac{QK^T}{\sqrt{d_{\text{head}}}}\right) V$$

Where $Q$, $K$, and $V$ are the query, key, and value matrices, and $d_{\text{head}}$ is the dimension of each attention head.

Skipping over the details of how this formula is derived, $Q$ represents the current token coming in, the one we want to predict the next token for, and $K$ and $V$ represent what we know about all the tokens we've seen so far (including the current token).

We take the dot product of $Q$ and $K$ to get a similarity score between what we've seen so far (the context) and our current active token, dividing by $\sqrt{d_{\text{head}}}$ to keep the scores in a reasonable range. We pipe these scores through a softmax to turn them into a probability distribution, and then multiply by $V$ to produce a weighted combination of value vectors: tokens with higher relevance contribute more to the output.
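To make the formula concrete, here is a minimal NumPy sketch of scaled dot-product attention for a single head. The function names and toy shapes are illustrative assumptions, not petite-vllm's implementation. No causal mask is applied because the only query is the newest token, which is allowed to attend to every token in the context.

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtract the row max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(q: np.ndarray, k: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Scaled dot-product attention for one head.

    q: [s_active, d_head]  queries for the tokens being processed
    k: [s_total, d_head]   keys for every token seen so far
    v: [s_total, d_head]   values for every token seen so far
    """
    d_head = q.shape[-1]
    scores = q @ k.T / np.sqrt(d_head)  # [s_active, s_total] similarity scores
    weights = softmax(scores, axis=-1)  # each row is a probability distribution
    return weights @ v                  # weighted combination of value vectors

rng = np.random.default_rng(0)
q = rng.standard_normal((1, 8))  # one new token
k = rng.standard_normal((5, 8))  # five context tokens (including the new one)
v = rng.standard_normal((5, 8))
out = attention(q, k, v)
print(out.shape)  # (1, 8)
```

If every key is identical, the softmax weights are uniform and the output is simply the mean of the value vectors, which is a handy sanity check.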

This is a highly simplified explanation of what the attention mechanism is doing, but it hopefully gives some intuition as to why $K$ and $V$ are kind of a big deal. They depend on every token we've seen so far, so their size grows linearly with both sequence length and batch size.
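A back-of-the-envelope sketch of how much memory the cached $K$ and $V$ tensors need. The helper `kv_cache_bytes` is hypothetical, and the Qwen3-0.6B values (28 layers, 8 KV heads, head dimension 128, bf16 storage) are assumptions taken from the model's published config; double-check them against its `config.json` before relying on the numbers.

```python
def kv_cache_bytes(batch: int, seq_len: int, n_layers: int, n_kv_heads: int,
                   head_dim: int, bytes_per_elem: int = 2) -> int:
    # Factor of 2: one tensor for K and one for V, in every layer.
    return 2 * n_layers * batch * seq_len * n_kv_heads * head_dim * bytes_per_elem

# Assumed Qwen3-0.6B values (verify against the model's config.json),
# stored in bf16, i.e. 2 bytes per element.
QWEN3_06B = dict(n_layers=28, n_kv_heads=8, head_dim=128)

per_token = kv_cache_bytes(batch=1, seq_len=1, **QWEN3_06B)
print(per_token // 1024, "KiB per token")                                  # 112 KiB per token
print(kv_cache_bytes(batch=1, seq_len=512, **QWEN3_06B) // 2**20, "MiB")  # 56 MiB
print(kv_cache_bytes(batch=8, seq_len=512, **QWEN3_06B) // 2**20, "MiB")  # 448 MiB
```

Doubling either the batch size or the sequence length doubles the total, which is the linear scaling described above.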

So what exactly are $Q$, $K$, and $V$, and how are they produced?

$Q$, $K$, and $V$ are the outputs of multiplying the input hidden states $X$ by the projection weight matrices $W_q$, $W_k$, and $W_v$.
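A NumPy sketch of the three projections with toy dimensions (all assumed for illustration; $n_{kv} < n_q$ models grouped-query attention). The weights are stored flattened as `[H, n_heads * d_h]` and the outputs are reshaped into heads afterwards, which is equivalent to the `[H, n_heads, d_h]` view used in the shape equations. It also shows the fused variant: concatenating the three weight matrices and running one matmul yields the same $Q$, $K$, and $V$ after a split.

```python
import numpy as np

# Toy dimensions (assumed): batch B, tokens s, hidden size H,
# n_q query heads, n_kv key/value heads, head dimension d_h.
B, s, H = 2, 4, 64
n_q, n_kv, d_h = 4, 2, 16

rng = np.random.default_rng(0)
X = rng.standard_normal((B, s, H))
W_q = rng.standard_normal((H, n_q * d_h))
W_k = rng.standard_normal((H, n_kv * d_h))
W_v = rng.standard_normal((H, n_kv * d_h))

# Separate projections, reshaped so each head gets its own d_h-sized slice.
Q = (X @ W_q).reshape(B, s, n_q, d_h)
K = (X @ W_k).reshape(B, s, n_kv, d_h)
V = (X @ W_v).reshape(B, s, n_kv, d_h)

# Fused variant: one matmul against the concatenated weights, then split.
W_qkv = np.concatenate([W_q, W_k, W_v], axis=1)  # [H, (n_q + 2 * n_kv) * d_h]
qkv = X @ W_qkv
q_f, k_f, v_f = np.split(qkv, [n_q * d_h, (n_q + n_kv) * d_h], axis=-1)

print(Q.shape, K.shape, V.shape)  # (2, 4, 4, 16) (2, 4, 2, 16) (2, 4, 2, 16)
```

With $n_q = 16$, $n_{kv} = 8$, $d_h = 128$, and $H = 1024$ (assumed Qwen3-0.6B values), the fused weight is $[1024, (16 + 2 \cdot 8) \cdot 128] = [1024, 4096]$, which matches the matmul shape quoted for Qwen3-0.6B in this post.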

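Without a KV cache, each matmul has the following shape, where $B$ is the batch size, $H$ is the hidden size, $n_q$ and $n_{kv}$ are the numbers of query and key/value heads, $d_h$ is the head dimension, $s_{\text{active}}$ is the number of new tokens coming in, and $s_{\text{prior}}$ is the number of tokens already in the context.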

$$
\begin{aligned}
Q &= X \cdot W_q : [B,\; s_{\text{active}},\; H] \times [H,\; n_q,\; d_h] \to [B,\; s_{\text{active}},\; n_q,\; d_h] \\
K &= X \cdot W_k : [B,\; s_{\text{prior}} + s_{\text{active}},\; H] \times [H,\; n_{kv},\; d_h] \to [B,\; s_{\text{prior}} + s_{\text{active}},\; n_{kv},\; d_h] \\
V &= X \cdot W_v : [B,\; s_{\text{prior}} + s_{\text{active}},\; H] \times [H,\; n_{kv},\; d_h] \to [B,\; s_{\text{prior}} + s_{\text{active}},\; n_{kv},\; d_h]
\end{aligned}
$$

$Q$ depends only on $s_{\text{active}}$, but $K$ and $V$ depend on both $s_{\text{active}}$ and $s_{\text{prior}}$. In practice, the Q/K/V projections are typically fused into a single matmul for efficiency. This means that without KV caching, the input $X$ must include all prior tokens, so we end up recomputing $K$ and $V$ for every prior token even though their values haven't changed since the last step, and we compute $Q$ for prior tokens only to discard it immediately.
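A small NumPy sketch of that redundancy, using a single $K$ projection and toy shapes (assumptions for illustration). Projecting the first 9 tokens and then the first 10 produces the same first 9 rows twice. In a full decoder the same holds at every layer, because the causal mask means a token's hidden state never depends on later tokens. Counting projected rows over a 10-step generation shows the quadratic-versus-linear gap.

```python
import numpy as np

# Toy single-layer setup (assumed): one K projection, no batch axis.
H, d_h = 32, 16
rng = np.random.default_rng(0)
W_k = rng.standard_normal((H, d_h))
X = rng.standard_normal((10, H))  # hidden states for 10 tokens

# Without a cache, one step projects 9 tokens and the next step projects all 10.
K_step_9 = X[:9] @ W_k    # [9, d_h]
K_step_10 = X[:10] @ W_k  # [10, d_h]

# The first 9 rows come out the same both times; only the last row is new.
print(np.allclose(K_step_10[:9], K_step_9))  # True

# Rows of K projected over a 10-token generation, one token added per step.
steps = range(1, 11)
rows_without_cache = sum(steps)  # 1 + 2 + ... + 10 = 55, grows quadratically
rows_with_cache = len(steps)     # one new row per step = 10, grows linearly
print(rows_without_cache, rows_with_cache)  # 55 10
```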

This is why KV caching is such a critical compute optimization. With a cache, the projection matmul for $K$ and $V$ operates only over $s_{\text{active}}$ (typically 1 token during decode) instead of the full sequence. For Qwen3-0.6B at a sequence length of 512, that's a 512x reduction in the fused QKV projection: from a $[512, 1024] \times [1024, 4096]$ matmul down to $[1, 1024] \times [1024, 4096]$. We simply cache the $K$ and $V$ projections of prior tokens and append the new token's $K$ and $V$ to the cache on each forward pass.
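The arithmetic behind the 512x figure, as a short Python sketch. `matmul_flops` is a hypothetical helper using the usual convention of 2 FLOPs (one multiply, one add) per multiply-accumulate; the shapes are the ones quoted above, for one layer and one decode step.

```python
def matmul_flops(m: int, k: int, n: int) -> int:
    # An [m, k] x [k, n] matmul performs m * k * n multiply-accumulates,
    # conventionally counted as 2 FLOPs each.
    return 2 * m * k * n

H, QKV_OUT = 1024, 4096  # fused QKV projection shape quoted for Qwen3-0.6B
no_cache = matmul_flops(512, H, QKV_OUT)   # re-project all 512 tokens
with_cache = matmul_flops(1, H, QKV_OUT)   # project only the new token
print(no_cache, with_cache, no_cache // with_cache)  # 4294967296 8388608 512
```

That is roughly 4.3 GFLOPs versus 8.4 MFLOPs for this one projection, per layer, per generated token.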

With KV caching, our QKV projection becomes:

$$
\begin{aligned}
Q &= X \cdot W_q : [B,\; s_{\text{active}},\; H] \times [H,\; n_q,\; d_h] \to [B,\; s_{\text{active}},\; n_q,\; d_h] \\
K &= X \cdot W_k : [B,\; s_{\text{active}},\; H] \times [H,\; n_{kv},\; d_h] \to [B,\; s_{\text{active}},\; n_{kv},\; d_h] \\
V &= X \cdot W_v : [B,\; s_{\text{active}},\; H] \times [H,\; n_{kv},\; d_h] \to [B,\; s_{\text{active}},\; n_{kv},\; d_h]
\end{aligned}
$$
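Putting the pieces together, here is a minimal single-layer, single-head NumPy sketch of a KV cache during decode. `KVCache`, `attend`, and the toy shapes are hypothetical names chosen for illustration, not petite-vllm's API. Each step projects only the new token ($s_{\text{active}} = 1$), appends its $K$ and $V$ to the cache, and attends over everything cached; the final check confirms the result matches recomputing $K$ and $V$ over the whole sequence. Because it feeds one token at a time the sketch needs no causal mask, whereas a multi-token prefill would.

```python
import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    x = x - x.max(axis=-1, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def attend(q: np.ndarray, k: np.ndarray, v: np.ndarray) -> np.ndarray:
    return softmax(q @ k.T / np.sqrt(q.shape[-1])) @ v

class KVCache:
    """Hypothetical cache for one layer and one head, grown along the sequence axis."""

    def __init__(self, d_head: int):
        self.k = np.zeros((0, d_head))
        self.v = np.zeros((0, d_head))

    def append(self, k_new: np.ndarray, v_new: np.ndarray):
        # k_new, v_new: [s_active, d_head] -> cache becomes [s_prior + s_active, d_head]
        self.k = np.concatenate([self.k, k_new], axis=0)
        self.v = np.concatenate([self.v, v_new], axis=0)
        return self.k, self.v

H, d = 32, 16
rng = np.random.default_rng(0)
W_q, W_k, W_v = (rng.standard_normal((H, d)) for _ in range(3))
X = rng.standard_normal((6, H))  # hidden states for 6 tokens, fed one at a time

cache = KVCache(d)
for t in range(X.shape[0]):
    x_t = X[t:t + 1]  # [1, H]: s_active = 1 during decode
    # Only the new token is projected; prior K/V rows come from the cache.
    k_all, v_all = cache.append(x_t @ W_k, x_t @ W_v)
    out_cached = attend(x_t @ W_q, k_all, v_all)

# Reference: recompute K and V over the whole sequence, as the uncached loop does.
out_full = attend(X[-1:] @ W_q, X @ W_k, X @ W_v)
print(np.allclose(out_cached, out_full))  # True
```

Concatenating copies the entire cache on every step, so this is for illustration only; serving engines allocate cache memory up front instead.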
