Making FlashAttention-4 faster for inference
June 11, 2026 • 15 minute read
Charles Frye (@charles_irl), Member of Technical Staff
Timothy Feng, Member of Technical Staff
David Wang (@_dcw02), Member of Technical Staff
When the FlashAttention-4 kernel source was released last year, we dove in and shared our findings about how the kernel works in exquisite detail. You can now confirm the high-level structure we inferred by reading the kernel authors’ own post, straight from the horse’s mouth.
In the intervening months, we’ve made a number of contributions to this kernel to make it more suitable for large language model (LLM) inference, in particular for decode-heavy workloads. Unlike pre-training workloads, LLM inference workloads are often dominated by the memory-bandwidth-limited “decode” or “token generation” phase (light blue, below).
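To see why decode is bandwidth-limited, it helps to count FLOPs per byte moved. The sketch below is a back-of-the-envelope model under stated assumptions, not a measurement: one attention head, bf16 tensors (2 bytes per element), 2 FLOPs per multiply-add, softmax FLOPs ignored, and each of Q, K, V, and O touched in memory exactly once. The helper name `attention_arithmetic_intensity` and the shapes are hypothetical.

```python
def attention_arithmetic_intensity(num_queries: int, seq_len: int, head_dim: int,
                                   bytes_per_elem: int = 2) -> float:
    """Approximate FLOPs per byte moved for one attention head.

    Hypothetical helper: counts 2 FLOPs per multiply-add for the two matmuls
    (Q @ K^T and P @ V), ignores softmax FLOPs, and assumes Q, K, V, and O
    each cross the memory bus exactly once.
    """
    # Q @ K^T and P @ V each cost num_queries * seq_len * head_dim multiply-adds.
    flops = 2 * (2 * num_queries * seq_len * head_dim)
    # K and V are seq_len x head_dim each; Q and O are num_queries x head_dim each.
    bytes_moved = (2 * seq_len + 2 * num_queries) * head_dim * bytes_per_elem
    return flops / bytes_moved


# Decode: a single new query token attends to the whole cached sequence.
decode = attention_arithmetic_intensity(num_queries=1, seq_len=8192, head_dim=128)
# Prefill: every token of the prompt is also a query.
prefill = attention_arithmetic_intensity(num_queries=8192, seq_len=8192, head_dim=128)
print(f"decode:  {decode:.2f} FLOP/byte")   # ~1 FLOP/byte
print(f"prefill: {prefill:.0f} FLOP/byte")  # 4096 FLOP/byte
```

Intensity grows with the number of queries that reuse each K/V load. With one query per sequence, decode lands around one FLOP per byte, far below the hundreds of FLOPs per byte a modern GPU needs to keep its tensor cores busy, so the step is bound by how fast K and V can be streamed.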
Inference workloads are also generally more variable: batch sizes and sequence lengths become non-uniform, and keys and values must (most of the time) be retrieved from a cache.
This requires new kernel code, and that code must be fast: “performance is the product”.
Before we dive into the details, here are some takeaways for a more general audience.
High-level takeaways about low-level programming
Our changes to extend the kernel to the inference workloads we wanted to run fall into two rough categories:
adjusting the parallelism strategy, i.e., changing the number of query tiles per thread block and switching from query parallelism to key/value parallelism, and
supporting irregular global memory accesses, i.e., issuing cp.async loads in place of the cp.async.bulk loads that go through the Tensor Memory Accelerator (TMA).
These two categories are represented by the following figures, which are explained in detail below.
One of our optimizations was to port the “split KV” technique to FA4, which parallelizes work across KV tiles (right-hand side).
Several of our optimizations required handling irregular memory accesses (right-hand side), which use different instructions and hardware than regular accesses (left-hand side).
Adjusting the parallelism strategy gives the most leverage for improving performance on modern, massively parallel hardware. Intuitively: if you are locked into a specific approach to parallelism, the sequential term in Amdahl’s Law is fixed. If you can change parallelism strategies, you can move work between the sequential and parallel components of your algorithm. Per the Law, moving work out of the sequential component is generally higher leverage than speeding up a fixed parallel component.
We didn’t choose the CUDA Templates Domain Specific Language (CuTe DSL); the original kernel authors did. But it worked well for us. It supports highly productive development loops through fast JIT compilation with minimal or zero run-time cost.
It also made many of our ideas more straightforward to express than older tools did. Note that because it uses templates, FA4 is really a family of kernels, if “kernel” means roughly “something that can be launched into a CUDA stream”. We’ll keep calling it a “kernel”.
CuTe DSL was nice. But, as we indicated in our previous post, FA4 is best understood algorithmically at the tile level, not at the warp level at which it is implemented. Proper tile-based programming would clearly be better for ergonomics and development speed (which, by the way, still matters in the age of agents). With a tile-based programming model, programmers can express and operate on tile-level flows more directly. That makes it easier, and cheaper in engineering effort, to change or add algorithms to kernels (the first category of changes). Furthermore, higher-level tile-based models make it easier for compilers to implement and optimize, say, both cp.async and TMA load paths (the second category) and to dispatch between them based on, say, size.
In this light, we’re very much looking forward to improved support for the CUDA Tile programming model, as distinct from the classic “CUDA SIMT” programming model, to build the attention and matmul kernels of the future.
What we did, why, and how we knew it was good
We organize our contributions by pull request. Each section begins with a “Figure of Merit”: the measurement used to indicate that the contribution improved performance.
We report these figures in the traditional format of the performance engineer: an ASCII table.
PR 2109: support FP8 inputs (merged April 17, 2026)
Figure of Merit: Up to 1.16x throughput relative to bf16 baseline
| Batch Size / Seq Len | BF16 TFLOP/s | FP8 TFLOP/s | Speedup |
| -------------------- | ------------ | ----------- | ------- |
| 1 / 16384            | 1569         | 1818        | 1.16x   |
| 32 / 512             | 962          | 1090        | 1.13x   |
Training models generally requires higher-precision floating-point numbers to properly accumulate the many small changes carried by gradients. But at inference time, we can get away with lower precision. Reducing the bit width by a factor of two reduces memory and arithmetic bandwidth demand by a factor of two...
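The memory half of that claim can be checked with simple arithmetic. The sketch below counts the K/V bytes a single attention layer must stream for one decode step, in bf16 (2 bytes per element) and FP8 (1 byte per element). It is a minimal illustration under assumptions: the helper names, the model shape (8 KV heads of dimension 128), and the 8 TB/s bandwidth figure are hypothetical, and the batch size and sequence length are borrowed from the table’s second row purely for concreteness.

```python
def kv_bytes_per_decode_step(batch: int, seq_len: int, num_kv_heads: int,
                             head_dim: int, bytes_per_elem: int) -> int:
    """Bytes of K and V one attention layer reads in a decode step (hypothetical helper)."""
    # K and V: one head_dim vector per cached token, per KV head, per sequence.
    return 2 * batch * seq_len * num_kv_heads * head_dim * bytes_per_elem


def bandwidth_bound_time_us(num_bytes: int, bandwidth_tb_per_s: float) -> float:
    """Lower-bound time in microseconds if the step is limited purely by bandwidth."""
    return num_bytes / (bandwidth_tb_per_s * 1e12) * 1e6


# Hypothetical shape: 8 KV heads of dimension 128; batch 32, 512 cached tokens.
shape = dict(batch=32, seq_len=512, num_kv_heads=8, head_dim=128)
bf16_bytes = kv_bytes_per_decode_step(**shape, bytes_per_elem=2)
fp8_bytes = kv_bytes_per_decode_step(**shape, bytes_per_elem=1)

ASSUMED_BANDWIDTH_TB_PER_S = 8.0  # illustrative memory bandwidth, not a measured value
for name, n in [("bf16", bf16_bytes), ("fp8", fp8_bytes)]:
    t = bandwidth_bound_time_us(n, ASSUMED_BANDWIDTH_TB_PER_S)
    print(f"{name}: {n / 2**20:.0f} MiB, >= {t:.1f} us")
```

Because the byte count is linear in `bytes_per_elem`, halving the element width halves both the traffic and the bandwidth-bound lower limit on step time, whatever the actual shape and hardware bandwidth are.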