Lighthouse Attention - NOUS RESEARCH
by Nous Research
arXiv 2605.06554
TL;DR. A selection-based hierarchical attention whose forward+backward pass runs ~17× faster than standard attention at 512K context on a single B200, and which delivers a 1.4–1.7× end-to-end pretraining speedup at 98K context. Q, K and V are pooled symmetrically across an $L$-level pyramid; per-head $\ell_2$ norms pick a small dense sub-sequence; FlashAttention runs on the gathered sub-sequence, with no custom sparse attention kernel, no straight-through estimator and no auxiliary loss. After the sparse stage, a brief resumption of training with standard attention converts the checkpoint back into a dense-attention model: every recovered run matches or beats a dense model trained from scratch on the same token budget. Validated on a 530M-parameter Llama-3 model (16k optimiser steps, 50B tokens), with 1M-token training across 32 B200s under context parallelism.
Long-context pretraining is bottlenecked by attention's quadratic compute cost. FlashAttention shaves the constants, but the wall is still there: you train at the contexts you can afford.
We introduce Lighthouse Attention, a selection-based hierarchical attention that pools queries, keys and values symmetrically across a multi-resolution pyramid, scores every pyramid entry with a parameter-free function, and keeps the selection logic outside the attention kernel. The expensive step in the forward pass is FlashAttention on a small dense sub-sequence. The same kernel runs at training and inference, so every upstream FlashAttention improvement is inherited unchanged.
The code is at github.com/ighoshsubho/lighthouse-attention.
Two design decisions
Most prior work in this space (NSA, HISA, InfLLM-v2, DSA, MoBA) makes two design decisions that quietly matter for training.
Asymmetry. Queries stay at full resolution; only keys and values are pooled. The hierarchy serves as a compressed addressable memory rather than a multi-scale representation.
Architectural entanglement. Selection lives inside the attention kernel, so the carefully optimised dense-attention kernels that run efficiently on modern tensor cores can't be reused; every sparse method ships its own kernel.
There is also a concern specific to training. An inference-time sparse method can be at most as good as its dense backbone: the sparse substitution is only ever evaluated against the dense forward pass. A training-time sparse method has to pass a harder test: once training is done, is the model still a competent dense-attention model? If not, training has produced a specialist in its own approximation.
We treat that question as the central correctness check.
The method
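Symmetric pooling. Q, K and V are all pooled by the same factor at every level of the hierarchy, so a pooled query at level $\ell$ lives in the same representation space as a pooled key at level $\ell$. This choice is what turns the dense-attention call from $O(N \cdot S \cdot d)$ to $O(S^2 \cdot d)$ at training time, where $N$ is the sequence length, $S$ the number of selected entries and $d$ the head dimension: with full-resolution queries, all $N$ queries attend to the $S$ selected keys, whereas with symmetric pooling only the $S$ selected entries attend to one another.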
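A minimal pure-Python sketch of symmetric pooling and of the resulting cost difference. It assumes non-overlapping mean pooling with window $p^{\ell}$ at level $\ell$ and a sequence length divisible by $p^{L-1}$. The helper names (`mean_pool`, `pyramid_pool`, `attn_macs`) are hypothetical, and `attn_macs` counts only the multiply-accumulates of $QK^\top$ and $PV$ for $n_q$ queries against $n_k$ keys, with $S$ standing for the number of selected entries.

```python
def mean_pool(x, w):
    """Average non-overlapping windows of w consecutive rows of x.

    x is a list of d-dimensional vectors (lists of floats); len(x) is
    assumed to be divisible by w.
    """
    d = len(x[0])
    return [
        [sum(x[i + j][c] for j in range(w)) / w for c in range(d)]
        for i in range(0, len(x), w)
    ]


def pyramid_pool(q, k, v, p, levels):
    """Pool Q, K and V with the same window p**lvl at every level (symmetric)."""
    return [
        (mean_pool(q, p ** lvl), mean_pool(k, p ** lvl), mean_pool(v, p ** lvl))
        for lvl in range(levels)
    ]


def attn_macs(n_q, n_k, d):
    """Multiply-accumulates for QK^T and PV: two n_q x n_k x d products."""
    return 2 * n_q * n_k * d


# Toy pyramid: N = 8 positions, d = 2, p = 2, L = 3 levels.
x = [[float(i), 0.0] for i in range(8)]
pyramid = pyramid_pool(x, x, x, p=2, levels=3)
print([len(level[0]) for level in pyramid])  # [8, 4, 2]

# Hypothetical sizes (not from the paper): N = 512K positions, S = 16K selected.
N, S, d = 524_288, 16_384, 128
print(attn_macs(N, S, d) // attn_macs(S, S, d))  # 32, i.e. N / S
```

Because mean pooling composes, pooling level $\ell-1$ by $p$ gives the same result as pooling the full sequence by $p^{\ell}$ (up to floating-point rounding), so a recursive implementation would be equivalent.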
Parameter-free scoring. Each pyramid entry gets two scalar scores: the $\ell_2$ norm of its query projection and the $\ell_2$ norm of its key projection. There is no learned scorer head, no auxiliary loss, no Gumbel-softmax and no straight-through estimator. The projections are encouraged to be useful when selected, not to score well at selecting. A dilated softmax-attention scorer is a strictly stronger signal (it sees QK interactions, which the norm scorer doesn't), so our results are a lower bound on what selection-based training can give.
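The scorer can be sketched in a few lines for one head at one pyramid level, assuming the pooled projections are available as plain lists of floats (`l2_norm` and `norm_scores` are hypothetical names). Nothing here is learned: each score is a plain $\ell_2$ norm.

```python
import math


def l2_norm(vec):
    """Euclidean norm of a vector given as a list of floats."""
    return math.sqrt(sum(x * x for x in vec))


def norm_scores(q_level, k_level):
    """Per-entry (query score, key score) for one head at one pyramid level.

    There are no learned parameters: each score is just an l2 norm.
    """
    return [(l2_norm(qv), l2_norm(kv)) for qv, kv in zip(q_level, k_level)]


q_level = [[3.0, 4.0], [0.0, 1.0]]
k_level = [[1.0, 0.0], [6.0, 8.0]]
print(norm_scores(q_level, k_level))  # [(5.0, 1.0), (1.0, 10.0)]
```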
Selection outside the kernel. Once top-K is decided, we gather the chosen entries into a contiguous, causally sorted dense sub-sequence and run FlashAttention on it. The expensive step at training time is therefore the same dense-attention kernel the dense baseline uses, and its forward and backward passes are bit-for-bit those of a dense Transformer, applied to the gathered sub-sequence.
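A sketch of the gather, dense-attention and scatter steps, under loudly stated assumptions: a naive causal softmax attention stands in for FlashAttention, each selected entry is identified by a hypothetical integer `(level, idx)` pair, and the causal order is taken to be the last original position an entry's span covers (`span_end`), with key $j$ visible to query $i$ only if its span ends no later. The paper's exact ordering and scatter rules may differ; here the scatter simply undoes the sort so that outputs line up with the selection.

```python
import math


def span_end(level, idx, p):
    """Last original position covered by entry idx at pyramid level `level`."""
    return (idx + 1) * p ** level - 1


def causal_attention(q, k, v, pos):
    """Naive causal softmax attention (a stand-in for FlashAttention).

    Query i may attend to key j only if pos[j] <= pos[i].
    """
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i, qi in enumerate(q):
        allowed = [j for j in range(len(k)) if pos[j] <= pos[i]]
        logits = [scale * sum(a * b for a, b in zip(qi, k[j])) for j in allowed]
        m = max(logits)  # subtract the max for numerical stability
        w = [math.exp(x - m) for x in logits]
        z = sum(w)
        out.append([
            sum(wt * v[j][c] for wt, j in zip(w, allowed)) / z
            for c in range(len(v[0]))
        ])
    return out


def gather_attend_scatter(pyramid, selected, p):
    """Run dense attention on the selected entries only.

    pyramid[level] = (Q_level, K_level, V_level); selected is a list of
    integer (level, idx) pairs, e.g. from top-K, in any order.
    """
    # Gather: sort the selection causally (ties broken by (level, idx)).
    order = sorted(range(len(selected)),
                   key=lambda s: (span_end(*selected[s], p), selected[s]))
    rows = [selected[s] for s in order]
    q = [pyramid[lvl][0][i] for lvl, i in rows]
    k = [pyramid[lvl][1][i] for lvl, i in rows]
    v = [pyramid[lvl][2][i] for lvl, i in rows]
    pos = [span_end(lvl, i, p) for lvl, i in rows]
    dense_out = causal_attention(q, k, v, pos)  # stand-in for the dense kernel
    # Scatter: undo the sort so out[s] belongs to selected[s].
    out = [None] * len(selected)
    for slot, s in enumerate(order):
        out[s] = dense_out[slot]
    return out


# Toy 2-level pyramid (p = 2, d = 1) with Q = K = V for brevity.
lvl0 = [[1.0], [2.0], [3.0], [4.0]]
lvl1 = [[1.5], [3.5]]
pyramid = [(lvl0, lvl0, lvl0), (lvl1, lvl1, lvl1)]
selected = [(0, 3), (1, 0), (0, 0)]
out = gather_attend_scatter(pyramid, selected, p=2)
print(out[2])  # [1.0]: position 0 can only attend to itself
```

The selection logic only decides which rows enter `causal_attention`; the attention routine itself never sees the pyramid or the indices, which is the separation the paragraph above describes.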
The four stages
A Lighthouse attention layer replaces standard scaled dot-product attention with four stages that surround, but do not modify, the attention kernel.
Figure 1. Lighthouse Attention. The forward path (black) projects $H_t$ into Q, K and V, applies the symmetric Pyramid Pool and then, guided by the indices $\mathcal{I}$ from the Hierarchical Selector (Score → Top-K), performs a dense gather, FlashAttention and a deterministic scatter-back to produce $O_t$. The selector branch is non-differentiable: top-K returns integer indices, so no gradient flows through Score or Top-K.
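The Score → Top-K branch can be sketched as follows. Two simplifications are assumptions, not the paper's method: the two norms are combined by summing, and top-K is taken flat over all levels rather than through a hierarchical selector. The output is a list of integer `(level, idx)` pairs (`select_top_k` is a hypothetical name), which is why no gradient can flow back through this branch.

```python
import heapq


def select_top_k(scores, k):
    """Return the integer (level, idx) indices of the k highest-scoring entries.

    scores maps (level, idx) -> (q_norm, k_norm) for one head. Summing the
    two norms is a placeholder; this is a flat top-K, not a hierarchical one.
    """
    combined = {entry: qn + kn for entry, (qn, kn) in scores.items()}
    top = heapq.nlargest(k, combined, key=combined.__getitem__)
    # Plain Python ints: nothing differentiable survives this step.
    return sorted(top)


scores = {(0, 0): (1.0, 0.5), (0, 1): (0.2, 0.1), (1, 0): (2.0, 2.0), (0, 2): (0.9, 0.9)}
print(select_top_k(scores, k=2))  # [(0, 2), (1, 0)]
```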
(i) Pyramid pool
Average-pool Q, K and V symmetrically into an $L$-level pyramid with pooling factor $p$:
$$ Q^{(\ell)} = \mathrm{Pool}_{\mu}(Q;\, p^{\ell}), \quad K^{(\ell)} = \mathrm{Pool}_{\mu}(K;\, p^{\ell}), \quad V^{(\ell)} = \mathrm{Pool}_{\mu}(V;\, p^{\ell}), \quad \ell = 0, 1, \ldots, L-1 $$
Level 0 is the full sequence; level $\ell$ has $N/p^{\ell}$...