FP8 GEMM Optimization on AMD CDNA™4 Architecture

ROCm Blogs


March 10, 2026 by Jiahui Cao, Amanzhol Salykov, Andy Luo.


This post continues our previous blog post, "Matrix Core Programming on AMD CDNA™3 and CDNA™4 Architecture", which introduced Matrix Cores and demonstrated how to use them in HIP kernels.

In this post, we take the next step and show how to use Matrix Cores in GEMM kernels, focusing on optimizing an FP8 GEMM kernel for AMD Instinct™ MI355X GPUs. If you are not yet familiar with Matrix Cores, we recommend reading the introductory post first and then returning to this article.

GPU Characteristics

Compared with the CDNA™3 architecture, CDNA™4 increases LDS capacity and read bandwidth (160 KB and 256 B/clk, up from 64 KB and 128 B/clk), widens the per-lane GLOBAL_LOAD_LDS transfer (128 bits vs. 32 bits), and broadens low-precision Matrix Core support, including FP4/FP6 dense matrix fused-multiply-add (MFMA) and block-scaled MFMA instructions. Table 1 summarizes the architectural differences that matter most for this GEMM kernel design.

| Feature | CDNA™4 | CDNA™3 |
|---|---|---|
| Wavefront size | 64 (wave64) | 64 (wave64) |
| LDS capacity | 160 KB addressable LDS per CU/workgroup | 64 KB |
| LDS bank count | 64 banks | 32 banks |
| LDS read bandwidth | 256 bytes/clock | 128 bytes/clock |
| GLOBAL_LOAD_LDS per-lane transfer | Up to 128 bits/lane | Up to 32 bits/lane |
| FP4/FP6 MFMA | Supported | Not supported |
| Block-scaled MFMA | Adds V_MFMA_SCALE_F32_16X16X128_F8F6F4 and V_MFMA_SCALE_F32_32X32X64_F8F6F4 | Not supported |
| FP16/BF16 MFMA shapes | Adds larger shapes (16x16x32, 32x32x16) on top of the CDNA™3 shapes | Up to 16x16x16 and 32x32x8 |

Table 1. Architectural differences between AMD CDNA™4 and CDNA™3 relevant to FP8 GEMM kernels.

Source data: AMD Instinct CDNA™4 ISA and AMD Instinct MI300 CDNA™3 ISA.

FP8 GEMM

In this blog post, we implement a GEMM kernel that computes \(C = A B^T\). The kernel multiplies matrix A of shape M×K by the transpose of matrix B, which has shape N×K, and writes the result to matrix C of shape M×N. The input matrices are stored in row-major order with the FP8 (E4M3FN) data type. The output matrix is BF16, also stored in row-major order. To minimize numerical accuracy loss, accumulation is performed in FP32.

Given the measured kernel duration \(t\) in seconds, we compute the achieved throughput with the following formulas:

\[
\mathrm{FLOPs} = 2MNK
\]

\[
\mathrm{TFLOP/s} = \frac{2MNK}{t} \cdot 10^{-12}
\]

hipBLASLt Benchmark

We use hipBLASLt as our performance target. The hipblaslt-bench tool benchmarks a hipBLASLt FP8 GEMM for a specific problem size, using rotating buffers and warm-up iterations. For example, to benchmark hipBLASLt on the problem size M=N=K=4096:

```shell
hipblaslt-bench --api_method c --stride_a 0 --stride_b 0 --stride_c 0 --stride_d 0 \
  --alpha 1 --beta 0 --transA T --transB N --batch_count 1 --scaleA 1 --scaleB 1 \
  --a_type f8_r --b_type f8_r --c_type bf16_r --d_type bf16_r \
  --scale_type f32_r --bias_type f32_r --compute_type f32_r --rotating 512 \
  --iters 1000 --cold_iters 1000 -m 4096 -n 4096 -k 4096 \
  --lda 4096 --ldb 4096 --ldc 4096 --ldd 4096
```

This gives about 2750 TFLOP/s on the AMD Instinct™ MI355X. For M=N=K=8192, hipBLASLt reaches about 3130 TFLOP/s[1]. See the hipblaslt-bench documentation for more information about the command-line interface and available options.

Naive FP8 GEMM

We start with the simplest version as a baseline. First, recall the GEMM form used here:

\[
C = A \cdot B^T,\quad C_{i,j} = \sum_{k=0}^{K-1} A_{i,k} \cdot B_{j,k}
\]

From this equation, the most direct mapping is one thread per output element C[row, col]: the thread loops over k and accumulates the dot product. This is easy to implement, but each thread reads an entire row of A and an entire row of B from global memory without any explicit sharing between threads, so the same elements are loaded many times and the kernel is expected to be memory-bound. The measured baseline reaches 1.15 TFLOP/s for M=N=K=4096.

Baseline code example:

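```cpp
#include <hip/hip_runtime.h>

// fp8e4m3 and bf16 are assumed to be aliases (defined elsewhere in the original source)
// for HIP's FP8 E4M3 and BF16 types, with conversions to and from float.

// One thread computes one output element:
// C[row, col] = alpha * dot(A[row, :], B[col, :]) + beta * C[row, col].
__global__ void baseline_fp8_gemm_kernel(const fp8e4m3* A,
                                         const fp8e4m3* B,
                                         bf16* C,
                                         int M,
                                         int N,
                                         int K,
                                         int lda,
                                         int ldb,
                                         int ldc,
                                         float alpha,
                                         float beta) {
    const int row = blockIdx.y * blockDim.y + threadIdx.y;
    const int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= M || col >= N) {
        return;  // Extra threads when M or N is not a multiple of the block size.
    }

    // FP32 accumulation for one dot product.
    float acc = 0.0f;
    for (int k = 0; k < K; ++k) {
        // A is M x K and B is N x K, both row-major: reading row col of B gives B^T.
        acc += float(A[row * lda + k]) * float(B[col * ldb + k]);
    }

    // Read the previous C value only when beta is nonzero.
    const float c_prev = (beta == 0.0f) ? 0.0f : static_cast<float>(C[row * ldc + col]);
    // Write back in BF16.
    C[row * ldc + col] = bf16(alpha * acc + beta * c_prev);
}
```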

LDS Tiling to Improve Data Reuse

In the...
