Deep Dive into 4-Wave Interleave FP8 GEMM


May 27, 2026 by Christian Gilli, Amanzhol Salykov, Andy Luo.


Our previous two posts in this GEMM optimization series covered Matrix Core instructions and the 8-wave ping-pong FP8 GEMM design. Here we discuss another algorithm design introduced by HipKittens, 4-wave interleave, which further improves on the performance of the 8-wave ping-pong implementation. For the most complete understanding, we recommend reading this post alongside the source code.

One Wave Per SIMD

To understand why 4-Wave Interleave exists, recall how the 8-Wave kernel worked: it placed two waves on each SIMD unit, one handling MFMA and one handling memory loads. These two waves alternate between memory and MFMA instructions ("ping-pong"), and the hardware scheduler overlaps them because they use different hardware units.

4-Wave Interleave takes the opposite approach. It places one wave per SIMD, and that single wave issues both the MFMA and the memory instructions, in a carefully hand-crafted order.

Figure 1: Wave assignment per CU

With only one wave on the SIMD, that wave gets all 512 VGPRs, twice the register budget available to each wave in the 8-Wave kernel. Doubling the register budget means the wave can hold a complete 128×128 output register tile at once, compared to 64×128 in the 8-wave case.
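The register arithmetic behind this claim can be sketched at compile time. The sketch below assumes 64 lanes per wavefront, 32-bit VGPRs, and FP32 accumulators for the FP8 MFMA results; none of these are stated in the text above, and `accum_vgprs` is a hypothetical helper name used only for illustration.

```cpp
#include <cassert>
#include <cstddef>

// Assumptions (not stated in the post): 64 lanes per wavefront,
// 32-bit VGPRs, and FP32 accumulators for the FP8 MFMA results.
constexpr std::size_t kLanesPerWave = 64;
constexpr std::size_t kBytesPerVgpr = 4;
constexpr std::size_t kAccumBytes   = 4;

// VGPRs each lane needs so that the wave as a whole can hold a
// rows x cols accumulator tile, spread evenly across its lanes.
constexpr std::size_t accum_vgprs(std::size_t rows, std::size_t cols) {
    return rows * cols * kAccumBytes / (kLanesPerWave * kBytesPerVgpr);
}

// 4-Wave: a 128x128 output tile per wave, out of a 512-VGPR budget.
static_assert(accum_vgprs(128, 128) == 256, "128x128 tile needs 256 VGPRs");
// 8-Wave: a 64x128 output tile per wave, out of a 256-VGPR budget.
static_assert(accum_vgprs(64, 128) == 128, "64x128 tile needs 128 VGPRs");
```

Under these assumptions the accumulators take half of the per-wave budget in both designs, which leaves the other half for the A/B operand tiles and addressing.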

The 4-Wave and 8-Wave designs differ in their implementation complexity. The 8-Wave kernel requires creating alternating wave behavior via conditional __builtin_amdgcn_s_barrier(), whereas the 4-Wave employs a finer-grained software pipeline that overlaps memory and MFMA instructions.

Algorithm Breakdown

In this section, we discuss the code and the implementation in detail.

For the 4-Wave kernel, we keep the same tile sizes as in the 8-Wave example from our previous article, that is, 256x256x128. We also only consider input matrices whose dimensions are multiples of the blocking parameters.
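As a small host-side illustration of the divisibility restriction, the sketch below derives how many output tiles and K steps a problem yields. `LaunchShape` and `launch_shape` are hypothetical names, and mapping `k_blocks` to K / 128 is our reading of the 256x256x128 blocking rather than something the post states explicitly.

```cpp
#include <cassert>

// Blocking parameters from the text: 256x256x128.
constexpr int BLOCK_M = 256;
constexpr int BLOCK_N = 256;
constexpr int BLOCK_K = 128;

// Hypothetical helper type: number of output tiles along M and N,
// and number of K steps each output tile iterates over.
struct LaunchShape {
    int blocks_m;
    int blocks_n;
    int k_blocks;
};

// Fills `out` and returns true only when every dimension is a positive
// multiple of the corresponding blocking parameter.
bool launch_shape(int M, int N, int K, LaunchShape& out) {
    if (M <= 0 || N <= 0 || K <= 0) return false;
    if (M % BLOCK_M != 0 || N % BLOCK_N != 0 || K % BLOCK_K != 0) return false;
    out = LaunchShape{M / BLOCK_M, N / BLOCK_N, K / BLOCK_K};
    return true;
}
```

For example, a 4096x4096x4096 problem gives a 16x16 grid of output tiles with 32 K steps each, while K = 4000 is rejected because it is not a multiple of 128.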

Figure 2 shows how the matrices are tiled with the parameters we just defined:

Figure 2: Algorithm structure

Each wave ends up handling four 64x64 tiles of the output matrix C, and each of these is made up of smaller 16x16 sub-tiles. These sub-tiles are what enable the interleaving: instead of issuing MFMA instructions in bulk to compute, e.g., an entire 64x64 tile, we issue smaller 16x16 MFMA instructions and mix them with memory instructions, creating custom software pipeline schemes. The idea is that by the time we finish processing an entire 64x64 tile of C, we will have already loaded the operands for the next one.

The code follows this structure:

    // A/B tiles in LDS are 256x128 with double buffering, split in two 128x128
    __shared__ fp8 A_lds[2][2][128 * 128];
    __shared__ fp8 B_lds[2][2][128 * 128];

    RT_C c[2][2]{};
    RT_A a[2]{};
    RT_B b[2]{};

    // Compute on cur, load on next
    int cur = 0, next = 1;

    // PROLOGUE
    // Load a 256x128 tile of A and B
    // load A_lds[cur] from A (global -> LDS)
    // load B_lds[cur] from B (global -> LDS)

    // Pre-load the next 256x128 tile of A and B
    // load A_lds[next] from A (global -> LDS)
    // load B_lds[next] from B (global -> LDS)

    // Pre-load registers
    // load a[0] from A_lds[cur][0] (LDS -> register)
    // load b[0] from B_lds[cur][0] (LDS -> register)

    // MAIN LOOP
    for (int k = 0; k < K_BLOCKS - 2; ++k) {
        interleaved_block(
            A_lds[cur][0], // Where to store the next 128x128 tile in LDS
            A,             // Where to load the next 128x128 tile from global memory
            b[1],          // Where to store the next 64x128 sub-tile in registers
            B_lds[cur][1], // Where to load the next 64x128 from LDS
            a[0], b[0],    // Current MFMA operands (already in registers)
            c[0][0]        // Accumulator tile
        ); // c[0][0] += a[0] * b[0]

        // The remaining calls follow the same structure, rotating which buffer is being filled
        interleaved_block(B_lds[cur][0], B, a[1], A_lds[cur][1], a[0], b[1], c[0][1]); // c[0][1] += a[0] * b[1]
        interleaved_block(B_lds[cur][1], B, a[0], A_lds[next][0], a[1], b[0], c[1][0]); // c[1][0] += a[1] * b[0]
        interleaved_block(A_lds[cur][1], A, b[0], B_lds[next][0], a[1], b[1], c[1][1]); // c[1][1] += a[1] * b[1]

        // Swap cur with next
        cur ^= 1;
        next ^= 1;
    }

    // EPILOGUE
    // Last two iterations (k = K_BLOCKS - 2, k = K_BLOCKS - 1) + store
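To see why the main loop stops two iterations early, the buffer rotation can be modeled on the host. This is a deliberately simplified sketch, not the kernel itself: it treats each 256x128 K-tile as a single unit, whereas the real kernel refills the buffers in 128x128 pieces interleaved with the MFMA work. `simulate_pipeline` is a hypothetical name.

```cpp
#include <cassert>
#include <vector>

// Simplified host-side model of the double-buffered K loop.
// Returns the order in which K-tiles are consumed by the compute stage.
std::vector<int> simulate_pipeline(int k_blocks) {
    assert(k_blocks >= 2);  // the prologue always pre-loads two tiles

    int lds[2];             // which K-tile each LDS buffer currently holds
    int cur = 0, next = 1;
    std::vector<int> consumed;

    // Prologue: tiles 0 and 1 are loaded before the loop starts.
    lds[cur] = 0;
    lds[next] = 1;

    // Main loop: compute on tile k while refilling its buffer with tile k + 2.
    for (int k = 0; k < k_blocks - 2; ++k) {
        consumed.push_back(lds[cur]);  // MFMAs consume tile k
        lds[cur] = k + 2;              // global -> LDS refill of the same buffer
        cur ^= 1;
        next ^= 1;
    }

    // Epilogue: the last two tiles are already resident, nothing left to load.
    consumed.push_back(lds[cur]);
    consumed.push_back(lds[next]);
    return consumed;
}
```

In this model every tile index from 0 to `k_blocks - 1` is consumed exactly once and in order, and the last refill issued by the loop is for tile `k_blocks - 1`, which is why the final two iterations need no further global loads.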

The core of the kernel is interleaved_block. Each call does three things:

Issues 16 MFMA instructions (a 4x4 grid of 16x16x128 MFMA) computing a 64x64 part of C.

Issues 8 LDS → register loads (loading a 64x128 fragment of A/B) to prepare the operands for the next call.

Issues 4 global → LDS loads, loading a 128x128 sub-tile of A/B that will be used in the next call to bring...
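The instruction mix of one call can be illustrated with a toy issue schedule. The real ordering inside `interleaved_block` is hand-tuned and is not reproduced here; the sketch below simply assumes the memory operations are spread evenly between the 16 MFMAs, which is our simplification. `Op`, `sketch_schedule`, and `count_ops` are hypothetical names.

```cpp
#include <cassert>
#include <vector>

enum class Op { MFMA, LdsLoad, GlobalLoad };

// Hypothetical evenly spread schedule for one interleaved_block call:
// 16 MFMAs, 8 LDS -> register loads, 4 global -> LDS loads.
// The actual kernel uses a hand-crafted order; this only shows the mix.
std::vector<Op> sketch_schedule() {
    constexpr int kMfma = 16, kLds = 8, kGlobal = 4;
    std::vector<Op> ops;
    for (int i = 0; i < kMfma; ++i) {
        ops.push_back(Op::MFMA);
        // One LDS load after every 2nd MFMA (8 in total).
        if (i % (kMfma / kLds) == 0) ops.push_back(Op::LdsLoad);
        // One global load after every 4th MFMA (4 in total).
        if (i % (kMfma / kGlobal) == 0) ops.push_back(Op::GlobalLoad);
    }
    return ops;
}

// Counts how many operations of a given kind the schedule contains.
int count_ops(const std::vector<Op>& ops, Op kind) {
    int n = 0;
    for (Op op : ops) {
        if (op == kind) ++n;
    }
    return n;
}
```

The point of the sketch is the ratio: each memory instruction has several MFMAs issued around it, so its latency can be hidden behind matrix work issued by the same wave.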
