FlashAttention-4: Algorithm and Kernel Pipelining Co-Design for Asymmetric Hardware Scaling - Colfax Research
Ted Zadouri (1,2), Markus Hoehnerbach (3), Jay Shah (4), Timmy Liu (5), Vijay Thakkar (3,6), Tri Dao (1,2)
(1) Princeton University, (2) Together AI, (3) Meta, (4) Colfax Research, (5) NVIDIA, (6) Georgia Tech
Modern accelerators like Blackwell GPUs continue the trend of asymmetric hardware scaling, where tensor core throughput grows far faster than other resources such as shared memory bandwidth, the special function units (SFUs) that compute transcendental operations like the exponential, and the general-purpose integer and floating-point ALUs. From the Hopper H100 to the Blackwell B200, for instance, BF16 tensor core throughput grows 2.25×, from 1 to 2.25 PFLOPs/s, while both the SFU count and shared memory bandwidth remain unchanged.
This scaling asymmetry has profound implications for optimizing complex kernels like attention on the Blackwell architecture. At its core, attention comprises two GEMMs, S = QK^T and O = PV, with a softmax in between (P = softmax(S)); in practice, it also involves substantial plumbing and bookkeeping: data movement, synchronization, layout transforms, element-wise ops, scheduling, masking, and so on.
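For reference, the two GEMMs with a softmax in between can be written as a naive pure-Python sketch on list-of-lists matrices. It scales scores by 1/sqrt(d) before the softmax, a standard convention that the text leaves implicit, and it materializes the full S and P matrices, which is exactly what FlashAttention avoids by tiling the computation and using an online softmax. The helper names are illustrative, not part of any library.

```python
import math


def matmul(a, b):
    """Plain (rows x inner) @ (inner x cols) product on lists of lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def softmax_rows(s):
    """Row-wise softmax, subtracting each row's max for numerical stability."""
    out = []
    for row in s:
        m = max(row)
        e = [math.exp(x - m) for x in row]
        z = sum(e)
        out.append([x / z for x in e])
    return out


def attention(q, k, v):
    """Naive attention: GEMM, softmax, GEMM."""
    scale = 1.0 / math.sqrt(len(q[0]))                        # 1/sqrt(d) score scaling
    k_t = [list(col) for col in zip(*k)]                      # K^T
    s = [[scale * x for x in row] for row in matmul(q, k_t)]  # GEMM 1: S = Q K^T
    p = softmax_rows(s)                                       # softmax in between
    return matmul(p, v)                                       # GEMM 2: O = P V
```

With an all-zero query, every score is equal, so the weights are uniform and the output is the column mean of V: `attention([[0.0, 0.0]], [[1.0, 2.0], [3.0, 4.0]], [[1.0, 2.0], [3.0, 4.0]])` returns `[[2.0, 3.0]]`.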
A naive view of attention is that the speed of the GEMMs alone determines kernel performance, so the other components can be disregarded, at least to first order. However, a “feeds and speeds” analysis for B200 in fact shows the opposite: the main performance bottleneck is not how fast the tensor cores can perform MMAs, but rather (a) the SFUs that compute the softmax exponentials in the forward (FWD) pass, and (b) shared-memory traffic in the backward (BWD) pass.
In this blog post, we present FlashAttention-4, an algorithm and kernel co-design that maximizes overlap between the matmuls and these other resource bottlenecks. On B200 with BF16, it reaches up to 1605 TFLOPs/s (71% utilization), up to 1.3× faster than cuDNN 9.13 and 2.7× faster than Triton.
Our main algorithmic and kernel co-design ideas are as follows:
New pipelining for maximum overlap: New forward and backward software pipelines that exploit Blackwell's fully asynchronous MMA and larger tile sizes, overlapping tensor core work, softmax exponentials, and memory operations.
Forward (FWD) pass: A software emulation of the exponential function, implemented via polynomial approximation on the FMA units to relieve the exponential bottleneck, plus conditional online softmax rescaling.
Backward (BWD) pass: Storing intermediate results in tensor memory to relieve shared-memory traffic, combined with Blackwell's new 2-CTA MMA mode to reduce shared-memory traffic further and halve the number of atomic reductions, plus support for a deterministic execution mode for reproducible training.
Scheduling: A new tile scheduler that mitigates load imbalance from causal masking and variable sequence lengths.
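To make the two forward-pass ideas concrete, here is a minimal pure-Python sketch; it is not FlashAttention-4's kernel code. `exp2_poly` computes 2^x the way a software exponential on FMA units can: split x into an integer part n and a fraction f, evaluate a polynomial for 2^f with Horner's rule (one FMA per step on a GPU), and fold n directly into the float32 exponent bits. `softmax_weighted_row` then uses it in an online softmax over key blocks that only rescales its running statistics when a block's max exceeds the current reference max by more than a threshold. The coefficients (a truncated Taylor series rather than a fitted minimax polynomial), the block size, and the threshold `tau=8.0` are illustrative assumptions, and scores are assumed to be pre-multiplied by log2(e) so that exp2 can be used, as FlashAttention kernels commonly do.

```python
import math
import struct

# Illustrative coefficients for 2**f on [0, 1): truncated Taylor series of
# e**(f * ln2). Real kernels would use fitted coefficients instead.
LN2 = math.log(2.0)
COEFFS = [LN2 ** k / math.factorial(k) for k in range(6)]  # c_k = ln(2)**k / k!


def exp2_poly(x: float) -> float:
    """Approximate 2**x with polynomial arithmetic plus exponent-bit manipulation."""
    # Clamp so the result stays a normal float32; inputs below -126 therefore
    # yield 2**-126 rather than 0 in this sketch.
    x = min(max(x, -126.0), 127.0)
    n = math.floor(x)  # integer part goes straight into the exponent field
    f = x - n          # fractional part in [0, 1)
    p = COEFFS[-1]
    for c in reversed(COEFFS[:-1]):  # Horner's rule: each step maps to one FMA on a GPU
        p = p * f + c
    # p approximates 2**f in [1, 2); multiply by 2**n by adding n to the exponent bits.
    bits = struct.unpack("<I", struct.pack("<f", p))[0]
    bits += n << 23
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def softmax_weighted_row(scores, values, block=4, tau=8.0):
    """softmax(scores) @ values for one query row, processed in key blocks.

    Scores are in log2 units, so weights are 2**s. The reference max m only
    moves when a block max exceeds it by more than tau, so every weight
    2**(s - m) stays <= 2**tau.
    """
    m = -math.inf                  # reference max shared by numerator and denominator
    denom = 0.0                    # running sum of 2**(s - m)
    acc = [0.0] * len(values[0])   # running sum of 2**(s - m) * v
    rescales = 0
    for start in range(0, len(scores), block):
        s_blk = scores[start:start + block]
        v_blk = values[start:start + block]
        m_blk = max(s_blk)
        if m_blk > m + tau:  # conditional rescale: only when the max grows by more than tau
            scale = 0.0 if m == -math.inf else exp2_poly(m - m_blk)
            denom *= scale
            acc = [a * scale for a in acc]
            m = m_blk
            rescales += 1
        for s, v in zip(s_blk, v_blk):
            w = exp2_poly(s - m)
            denom += w
            acc = [a + w * vj for a, vj in zip(acc, v)]
    return [a / denom for a in acc], rescales
```

Skipping a rescale is safe because the numerator and the denominator are both accumulated against the same stale max, which cancels in the final division; the threshold only bounds how large the unnormalized weights can grow before a rescale becomes necessary.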
FlashAttention-4 is available at: https://github.com/Dao-AILab/flash-attention/tree/main/flash_attn/cute.
arXiv: https://arxiv.org/abs/2603.05451
New hardware features on Blackwell
Tensor memory (TMEM): On B200, each of the 148 SMs has 256 KB of TMEM, an on-chip scratchpad wired into the tensor cores for warp-synchronous intermediate storage.
Fully asynchronous 5th-gen tensor cores: tcgen05.mma is asynchronous and accumulates in TMEM. For BF16 and FP16, the largest single-CTA UMMA tile is 128×256×16, twice the size of the largest Hopper WGMMA atom (64×256×16). Because the accumulator lives in TMEM rather than in registers and UMMA is launched by a single thread, register pressure drops, making larger tiles and deeper pipelines practical without the register spilling that plagued Hopper's warpgroup MMA. This also makes warp specialization more viable: some warps move tiles while others issue MMAs, overlapping matrix multiply-accumulate with softmax and memory traffic. tcgen05.mma can also source operand A from TMEM.
2-CTA MMA: Blackwell can execute one UMMA across a CTA pair in the same cluster, spanning the TMEM of both peer CTAs. One thread in the leader CTA launches the MMA, but both CTAs must stay active while it is in flight. This scales the MMA tile up to 256×256×16 by splitting M and N across the pair, reducing redundant traffic and lowering the per-CTA footprint. The CTA group size (1 or 2) must remain constant across all TMEM and tensor core operations within a kernel.
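One way to reason about TMEM budgets is to count columns. The sketch below assumes the TMEM geometry described in the PTX ISA for tcgen05 (128 lanes by 512 columns of 32-bit cells per SM, which accounts for the 256 KB above), and it assumes that an fp32 accumulator holding one row per lane occupies one column per output column, with each CTA of a 2-CTA pair holding its 128-row half of a 256-row accumulator. The function names and the buffer names in the test are hypothetical, not FlashAttention-4's actual allocation.

```python
LANES = 128        # TMEM lanes (rows) per SM
COLUMNS = 512      # 32-bit cells per lane
CELL_BYTES = 4     # each cell holds one fp32 value


def tmem_bytes() -> int:
    """Total TMEM capacity implied by the lane/column geometry."""
    return LANES * COLUMNS * CELL_BYTES


def accumulator_columns(m: int, n: int, cta_group: int = 1) -> int:
    """TMEM columns one CTA needs for an m x n fp32 accumulator.

    Only models one accumulator row per lane: m == 128 for cta_group=1,
    or m == 256 split across the CTA pair for cta_group=2.
    """
    if cta_group not in (1, 2):
        raise ValueError("CTA group size must be 1 or 2")
    if m // cta_group != LANES:
        raise ValueError("sketch only models 128 accumulator rows per CTA")
    return n  # one 32-bit column per output column


def fits(buffers: dict, cta_group: int = 1):
    """Return (fits, columns_used) for a dict of name -> (m, n) fp32 accumulators."""
    used = sum(accumulator_columns(m, n, cta_group) for m, n in buffers.values())
    return used <= COLUMNS, used
```

Under these assumptions, a 128×128 score accumulator plus a 128×128 output accumulator use 256 of the 512 columns, and a 256×256 accumulator in 2-CTA mode also needs only 256 columns per CTA, since each CTA stores half of the rows.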
Feeds and Speeds
Here are the per-SM feeds on B200, used below to cost one attention tile with M = N = D = 128:
Tensor Cores (BF16): 8192 ops/cycle
Exponential unit: 16 ops/cycle
Shared memory traffic: 128 bytes/cycle
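As a sanity check, the per-tile cycle counts listed next follow from these feeds with a few lines of arithmetic. Each M×N×D MMA costs 2·M·N·D ops (one multiply and one add per term), and the softmax needs one exponential per element of the M×N score tile. For shared memory, the sketch uses one accounting that reproduces the 768-cycle forward figure: three 128×128 BF16 operand tiles read per iteration (Q and K for S = QK^T, V for O = PV, with P sourced from TMEM). That decomposition is our assumption, not a statement of the kernel's exact traffic, and the function names are illustrative.

```python
TC_OPS_PER_CYCLE = 8192     # BF16 tensor-core ops per cycle per SM
EXP_PER_CYCLE = 16          # exponentials per cycle per SM
SMEM_BYTES_PER_CYCLE = 128  # shared-memory bytes per cycle per SM
BF16_BYTES = 2


def mma_cycles(num_mmas: int, m: int = 128, n: int = 128, d: int = 128) -> int:
    """Tensor-core cycles for num_mmas MMAs of shape m x n x d."""
    return num_mmas * 2 * m * n * d // TC_OPS_PER_CYCLE


def exp_cycles(m: int = 128, n: int = 128) -> int:
    """Exponential-unit cycles for one exponential per score element."""
    return m * n // EXP_PER_CYCLE


def smem_cycles(tiles_read: int, rows: int = 128, cols: int = 128) -> int:
    """Cycles to stream tiles_read BF16 tiles through shared memory."""
    return tiles_read * rows * cols * BF16_BYTES // SMEM_BYTES_PER_CYCLE
```

With these feeds, `mma_cycles(2)`, `exp_cycles()`, and `smem_cycles(3)` give 1024, 1024, and 768, and `mma_cycles(5)` gives the 2560-cycle backward tensor-core figure. The forward tie between tensor cores and the exponential unit means the exponentials alone cost as much as both GEMMs combined, which is why the forward pass targets the exponential.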
And the speeds (clock cycles per tile):
Forward (2 MMAs + MN exp):
Tensor Cores: 1024
Exp: 1024
SMEM: 768
Backward, 1-CTA (5 MMAs + MN exp):
Tensor Cores: 2560
Exp:...