FP8 GEMM Optimization on AMD CDNA™4 Architecture — ROCm Blogs
March 10, 2026 by Jiahui Cao, Amanzhol Salykov, Andy Luo.
Software tools & optimizations
AI/ML, C++, Linear Algebra, Performance, HPC, Optimization, Hardware
This blog post builds on our previous post, Matrix Core Programming on AMD CDNA™3 and CDNA™4 Architecture, which introduced Matrix Cores and demonstrated how to use them in HIP kernels.
In this post, we take the next step and show how to use Matrix Cores in GEMM kernels, with a particular focus on optimizing an FP8 GEMM kernel on AMD Instinct™ MI355X GPUs. If you are not yet familiar with Matrix Cores, we recommend reading the introductory post first and then returning to this article.
GPU Characteristics
Compared with the CDNA™3 architecture, CDNA™4 increases LDS capacity and read bandwidth (160 KB and 256 B/clk), widens the per-lane GLOBAL_LOAD_LDS transfer (128 bits vs. 32 bits), and broadens low-precision Matrix Core support with FP4/FP6 dense matrix fused-multiply-add (MFMA) instructions and block-scaled MFMA instructions. Table 1 summarizes the architectural differences most relevant to this GEMM kernel design.
| Feature | CDNA™4 | CDNA™3 |
|---|---|---|
| Wavefront size | 64 (wave64) | 64 (wave64) |
| LDS capacity | 160 KB per CU/workgroup addressable LDS | 64 KB LDS |
| LDS bank count | 64 banks | 32 banks |
| LDS read bandwidth | 256 bytes/clock | 128 bytes/clock |
| GLOBAL_LOAD_LDS per-lane transfer | Up to 128 bits/lane | Up to 32 bits/lane |
| FP4/FP6 MFMA | Supported | Not supported |
| Block-scaled MFMA | Adds V_MFMA_SCALE_F32_16X16X128_F8F6F4 and V_MFMA_SCALE_F32_32X32X64_F8F6F4 | Not supported |
| FP16/BF16 MFMA shapes | Adds larger shapes (16x16x32, 32x32x16) in addition to CDNA™3 | Up to 16x16x16 and 32x32x8 |
Table 1. Architectural differences between AMD CDNA™4 and CDNA™3 relevant to FP8 GEMM kernels.
Source data: AMD Instinct CDNA™4 ISA and AMD Instinct MI300 CDNA™3 ISA.
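To make the LDS and GLOBAL_LOAD_LDS rows of Table 1 concrete, the following C++ sketch works through the arithmetic for a hypothetical double-buffered tile: 256 rows of A and 256 rows of B, with a K-slice of 128 FP8 elements per stage. The tile shape, stage count, and constant names are illustrative assumptions, not the configuration of the kernels in this post; the point is only how quickly FP8 operand tiles outgrow a 64 KB LDS.

```cpp
// Illustrative arithmetic only: the tile shape below is a hypothetical choice,
// not the configuration used by the kernels in this post.
constexpr int kTileM = 256;   // rows of A staged per workgroup (assumed)
constexpr int kTileN = 256;   // rows of B staged per workgroup (assumed)
constexpr int kTileK = 128;   // K-slice staged per pipeline stage (assumed)
constexpr int kStages = 2;    // double buffering (assumed)
constexpr int kFp8Bytes = 1;  // one byte per FP8 E4M3FN element

// LDS needed to hold one stage of the A and B tiles, and all stages together.
constexpr int kStageBytes = (kTileM + kTileN) * kTileK * kFp8Bytes;  // 64 KiB
constexpr int kLdsBytes = kStageBytes * kStages;                     // 128 KiB

// LDS capacities from Table 1.
constexpr int kLdsCdna3 = 64 * 1024;
constexpr int kLdsCdna4 = 160 * 1024;
static_assert(kLdsBytes <= kLdsCdna4, "fits in CDNA4 LDS");
static_assert(kLdsBytes > kLdsCdna3, "would not fit in CDNA3 LDS");

// Bytes one wave64 GLOBAL_LOAD_LDS instruction moves at its maximum per-lane width.
constexpr int kWaveSize = 64;
constexpr int kLoadBytesCdna4 = kWaveSize * 128 / 8;  // 1024 bytes
constexpr int kLoadBytesCdna3 = kWaveSize * 32 / 8;   // 256 bytes

// Wave-level load instructions needed to stage one K-slice (summed over all waves).
constexpr int kLoadsPerStageCdna4 = kStageBytes / kLoadBytesCdna4;  // 64
constexpr int kLoadsPerStageCdna3 = kStageBytes / kLoadBytesCdna3;  // 256
```

Under these assumptions, the wider per-lane transfer cuts the number of load instructions needed to fill one stage by 4x, and the double-buffered footprint fits only in the larger CDNA™4 LDS.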
FP8 GEMM
In this blog post, we implement a GEMM kernel that computes \(C=A B^T\). The kernel multiplies matrix A of shape M×K by the transpose of matrix B, which has shape N×K, and writes the result to matrix C of shape M×N. The input matrices are stored in row-major order and use the FP8 (E4M3FN) data type. The output matrix uses the BF16 data type and is also stored in row-major order. To minimize loss of numerical accuracy, accumulation is performed in FP32.
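When developing GPU kernels, a host reference is useful for correctness checks. The following C++ sketch is one possible reference (our assumption, not code from this post): it decodes OCP FP8 E4M3FN bytes by hand (1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits, no infinities, S.1111.111 as NaN), accumulates in FP32, and rounds the result to BF16 with round-to-nearest-even. All function names are hypothetical, and values are passed as raw bytes and raw BF16 bits so that no GPU headers are needed.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Decode one OCP FP8 E4M3FN byte: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits.
// E4M3FN has no infinities; S.1111.111 is the only NaN encoding.
float fp8_e4m3fn_to_float(uint8_t v) {
  const int sign = (v >> 7) & 0x1;
  const int exp = (v >> 3) & 0xF;
  const int man = v & 0x7;
  float mag;
  if (exp == 0xF && man == 0x7) {
    mag = NAN;
  } else if (exp == 0) {
    mag = std::ldexp(man / 8.0f, -6);  // subnormal: 2^-6 * (m / 8)
  } else {
    mag = std::ldexp(1.0f + man / 8.0f, exp - 7);  // normal: 2^(e-7) * (1 + m / 8)
  }
  return sign ? -mag : mag;
}

// BF16 raw bits -> FP32 (exact: BF16 is the upper half of an FP32 word).
float bf16_bits_to_float(uint16_t h) {
  const uint32_t bits = static_cast<uint32_t>(h) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// FP32 -> BF16 raw bits with round-to-nearest-even; NaNs stay NaN.
uint16_t float_to_bf16_bits(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  if ((bits & 0x7FFFFFFFu) > 0x7F800000u) {
    return static_cast<uint16_t>((bits >> 16) | 0x0040u);  // quiet NaN
  }
  const uint32_t bias = 0x7FFFu + ((bits >> 16) & 1u);
  return static_cast<uint16_t>((bits + bias) >> 16);
}

// Host reference for C = alpha * A * B^T + beta * C with FP32 accumulation.
// A: M x K, B: N x K, both row-major FP8 E4M3FN bytes; C: M x N row-major BF16 bits.
void reference_fp8_gemm(const std::vector<uint8_t>& A, const std::vector<uint8_t>& B,
                        std::vector<uint16_t>& C, int M, int N, int K,
                        float alpha, float beta) {
  assert(A.size() == static_cast<std::size_t>(M) * K);
  assert(B.size() == static_cast<std::size_t>(N) * K);
  assert(C.size() == static_cast<std::size_t>(M) * N);
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        // Row i of A and row j of B are both contiguous along K.
        acc += fp8_e4m3fn_to_float(A[static_cast<std::size_t>(i) * K + k]) *
               fp8_e4m3fn_to_float(B[static_cast<std::size_t>(j) * K + k]);
      }
      const std::size_t idx = static_cast<std::size_t>(i) * N + j;
      const float c_prev = (beta == 0.0f) ? 0.0f : bf16_bits_to_float(C[idx]);
      C[idx] = float_to_bf16_bits(alpha * acc + beta * c_prev);
    }
  }
}
```

Because B is stored as N×K, computing \(A B^T\) reads both operands along their contiguous K dimension, which is also the access pattern the GPU kernels exploit.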
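Given the measured kernel duration \(t\) in seconds, we compute the achieved throughput with the formulas below. Each of the \(MN\) output elements requires \(K\) multiply-add operations, which count as \(2K\) floating-point operations: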
\[
\mathrm{FLOPs} = 2MNK
\]
\[
\mathrm{TFLOP/s} = \frac{2MNK}{t} \cdot 10^{-12}
\]
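The two formulas translate directly into a pair of helpers, shown here as a small C++ sketch (the function names are our own). The inverse helper is handy for sanity checks: it converts a reported throughput back into the kernel time it implies.

```cpp
#include <cstdint>

// Total floating-point operations for an M x N x K GEMM: one multiply and one add
// per (i, j, k) triple.
constexpr double gemm_flops(int64_t M, int64_t N, int64_t K) {
  return 2.0 * static_cast<double>(M) * static_cast<double>(N) * static_cast<double>(K);
}

// Achieved throughput in TFLOP/s for a kernel that ran for `seconds`.
constexpr double gemm_tflops(int64_t M, int64_t N, int64_t K, double seconds) {
  return gemm_flops(M, N, K) / seconds * 1e-12;
}

// Kernel time in seconds implied by a given throughput in TFLOP/s.
constexpr double gemm_seconds(int64_t M, int64_t N, int64_t K, double tflops) {
  return gemm_flops(M, N, K) / (tflops * 1e12);
}
```

For example, at M=N=K=4096 a throughput of about 2750 TFLOP/s corresponds to a kernel time of roughly 50 µs, while 1.15 TFLOP/s corresponds to roughly 120 ms.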
hipBLASLt Benchmark
We use hipBLASLt as our performance target. The hipblaslt-bench tool benchmarks a hipBLASLt FP8 GEMM for a specific problem size, using rotating buffers and warm-up iterations. For example, to benchmark hipBLASLt on a problem of size M=N=K=4096:
```shell
hipblaslt-bench --api_method c --stride_a 0 --stride_b 0 --stride_c 0 --stride_d 0 \
    --alpha 1 --beta 0 --transA T --transB N --batch_count 1 --scaleA 1 --scaleB 1 \
    --a_type f8_r --b_type f8_r --c_type bf16_r --d_type bf16_r \
    --scale_type f32_r --bias_type f32_r --compute_type f32_r --rotating 512 \
    --iters 1000 --cold_iters 1000 -m 4096 -n 4096 -k 4096 \
    --lda 4096 --ldb 4096 --ldc 4096 --ldd 4096
```
This gives about 2750 TFLOP/s on the AMD Instinct MI355X. For M=N=K=8192, hipBLASLt achieves about 3130 TFLOP/s[1]. See the hipblaslt-bench documentation for more information about the command-line interface and available options.
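The purpose of rotating buffers, as we understand it, is to keep timed iterations from repeatedly hitting the same operands in cache, which would inflate the measured throughput. The C++ sketch below illustrates that idea only. It assumes that `--rotating 512` requests about 512 MiB of rotating memory and that each copy holds A, B, C, and D; it is not hipBLASLt's implementation, and the struct and function names are hypothetical.

```cpp
#include <cstdint>

// Hypothetical helper: bytes for one copy of the benchmark operands, using the
// types from the command above (FP8 A and B are 1 byte/element, BF16 C and D are 2).
struct GemmFootprint {
  int64_t M, N, K;
  constexpr int64_t bytes() const { return M * K + N * K + 2 * M * N + 2 * M * N; }
};

// Number of operand copies so that all copies together span at least `rotating_mib`
// MiB (assumption about how the rotating size is interpreted).
constexpr int64_t rotating_copies(GemmFootprint f, int64_t rotating_mib) {
  const int64_t target = rotating_mib * 1024 * 1024;
  const int64_t copies = (target + f.bytes() - 1) / f.bytes();  // ceiling division
  return copies < 1 ? 1 : copies;
}

// Timed iteration `iter` uses copy `iter % copies`, cycling through distinct memory.
constexpr int64_t copy_for_iteration(int64_t iter, int64_t copies) { return iter % copies; }
```

Under these assumptions, one M=N=K=4096 copy occupies 96 MiB, so six copies are needed to span 512 MiB.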
Naive FP8 GEMM
We start with the simplest version as a baseline. First, recall the GEMM form used here:
\[
C = A \cdot B^T,\quad C_{i,j} = \sum_{k=0}^{K-1} A_{i,k} \cdot B_{j,k}
\]
From this equation, the most direct mapping is one thread per output element C[row, col]: the thread loops over k and accumulates the dot product. This is easy to implement, but every thread reloads its row of A and its row of B from global memory without any explicit sharing between threads, so the kernel is expected to be memory-bound. The measured baseline reaches 1.15 TFLOP/s for M=N=K=4096.
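A quick count supports the memory-bound expectation. Ignoring any reuse provided by caches, each thread reads K bytes of A and K bytes of B (FP8 uses one byte per element) to perform 2K floating-point operations, and summing over all MN threads gives the total input traffic the kernel requests:

\[
I_{\text{naive}} = \frac{2K\ \text{FLOPs}}{2K\ \text{bytes}} = 1\ \text{FLOP/byte},
\qquad
\text{bytes requested} = 2MNK \quad \text{vs.} \quad \text{distinct input bytes} = MK + NK
\]

For M=N=K=4096 this is 128 GiB requested against only 32 MiB of distinct input data, so nearly all of the traffic is redundant, and performance depends on how much of it the caches happen to absorb.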
Baseline code example:
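```cpp
#include <hip/hip_runtime.h>
#include <hip/hip_bf16.h>
#include <hip/hip_fp8.h>

// Assumed type aliases (their definitions are not shown in this post):
// OCP FP8 E4M3 and BF16 types from the HIP headers.
using fp8e4m3 = __hip_fp8_e4m3;
using bf16 = __hip_bfloat16;

// One thread computes one element of C = alpha * A * B^T + beta * C.
// A is M x K (row stride lda), B is N x K (row stride ldb), C is M x N (row stride ldc).
__global__ void baseline_fp8_gemm_kernel(const fp8e4m3* A,
                                         const fp8e4m3* B,
                                         bf16* C,
                                         int M,
                                         int N,
                                         int K,
                                         int lda,
                                         int ldb,
                                         int ldc,
                                         float alpha,
                                         float beta) {
  const int row = blockIdx.y * blockDim.y + threadIdx.y;
  const int col = blockIdx.x * blockDim.x + threadIdx.x;
  if (row >= M || col >= N) {
    return;
  }

  // FP32 accumulation for one dot product.
  float acc = 0.0f;
  for (int k = 0; k < K; ++k) {
    acc += float(A[row * lda + k]) * float(B[col * ldb + k]);
  }

  // Read the previous C value only when beta is nonzero.
  const float c_prev = (beta == 0.0f) ? 0.0f : static_cast<float>(C[row * ldc + col]);
  // Write back in BF16.
  C[row * ldc + col] = bf16(alpha * acc + beta * c_prev);
}
```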
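The baseline kernel needs a 2D launch grid whose x dimension covers the N columns of C and whose y dimension covers the M rows, matching how it derives col from blockIdx.x and row from blockIdx.y. The C++ sketch below computes such a grid and emulates the kernel's index math on the host to confirm that every output element is written exactly once, including when M and N are not multiples of the block size. The 16×16 block shape (256 threads, or four wave64 wavefronts) and all helper names are assumptions; the post does not state the launch configuration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical block shape: 16 x 16 = 256 threads (four wave64 wavefronts).
constexpr int kBlockX = 16;  // threads along N (columns of C)
constexpr int kBlockY = 16;  // threads along M (rows of C)

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

struct Dim2 {
  int x, y;
};

// Grid dimensions: x covers the N columns, y covers the M rows.
constexpr Dim2 baseline_grid(int M, int N) { return {ceil_div(N, kBlockX), ceil_div(M, kBlockY)}; }

// Host-side emulation of the kernel's index math: counts how many threads write
// each output element after the `row >= M || col >= N` guard.
std::vector<int> writes_per_element(int M, int N) {
  assert(M > 0 && N > 0);
  std::vector<int> count(static_cast<std::size_t>(M) * N, 0);
  const Dim2 grid = baseline_grid(M, N);
  for (int by = 0; by < grid.y; ++by)
    for (int bx = 0; bx < grid.x; ++bx)
      for (int ty = 0; ty < kBlockY; ++ty)
        for (int tx = 0; tx < kBlockX; ++tx) {
          const int row = by * kBlockY + ty;
          const int col = bx * kBlockX + tx;
          if (row >= M || col >= N) continue;  // same bounds check as the kernel
          ++count[static_cast<std::size_t>(row) * N + col];
        }
  return count;
}
```

On the device, the same values would be passed to the kernel launch as `dim3 grid(g.x, g.y)` and `dim3 block(kBlockX, kBlockY)`, where `g = baseline_grid(M, N)`.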
LDS Tiling to Improve Data Reuse
In the...