Accelerating LLM Inference on AMD GPUs with Low-Latency GEMMs — ROCm Blogs
Accelerating LLM Inference on AMD GPUs with Low-Latency GEMMs
June 29, 2026 by Yutao Xu, Xiaobing Zhang, Hattie Wu, Felix Li, Lingpeng Jin, Carlus Huang, Peng Sun, Barsoum Emad.
Large language model inference is becoming increasingly interactive. Users expect chatbots, coding assistants, agents, and real-time copilots to respond quickly, stream tokens smoothly, and stay responsive under concurrent load. In that setting, decode-time latency is not just a backend metric. It directly affects perceived quality.
In this blog, you will explore one small but important part of that inference path: decode-time GEMMs with small M, large N and K, BF16/FP16 inputs, optional bias, and shapes that repeat across real models. These shapes can leave conventional GEMM tiling underutilized, which makes them a useful target for decode-path optimization.
The main technique is LDS-Pipelined Split-K GEMM: the long K reduction is split across CTAs, further sliced across warp groups inside each CTA, and kept moving through a multi-stage LDS memory pipeline. On AMD GPUs, LDS means Local Data Share, the on-chip scratchpad memory used for fast cooperation inside a CTA.
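The split-K part of this idea can be sketched in a few lines of plain Python. This is a minimal illustration, not the AITER FlyDSL kernel: the function names `gemm_nt` and `split_k_gemm_nt` and the `num_splits` parameter are hypothetical, each loop iteration stands in for one CTA's share of K, and the warp-group slicing and LDS pipeline are not modeled.

```python
def gemm_nt(a, b):
    """Reference C[M, N] = A[M, K] @ B[N, K]^T on plain Python lists."""
    return [[sum(x * y for x, y in zip(row_a, row_b)) for row_b in b]
            for row_a in a]


def split_k_gemm_nt(a, b, num_splits):
    """Split the K reduction into chunks and accumulate the partial results.

    Hypothetical sketch: each chunk models the K range one CTA would own.
    """
    m, n, k = len(a), len(b), len(a[0])
    chunk = -(-k // num_splits)  # ceiling division
    c = [[0] * n for _ in range(m)]
    for s in range(num_splits):
        k0, k1 = s * chunk, min((s + 1) * chunk, k)
        # Partial product over this chunk of K only.
        partial = gemm_nt([row[k0:k1] for row in a],
                          [row[k0:k1] for row in b])
        # Reduction step: add this chunk's partial tile into the output.
        for i in range(m):
            for j in range(n):
                c[i][j] += partial[i][j]
    return c


# Small integer inputs so the comparison below is exact.
M, N, K = 2, 3, 16
a = [[(i + kk) % 5 - 2 for kk in range(K)] for i in range(M)]
b = [[(j * kk) % 7 - 3 for kk in range(K)] for j in range(N)]

reference = gemm_nt(a, b)
result = split_k_gemm_nt(a, b, num_splits=4)
print(result == reference)
```

The example uses integers so the comparison is exact. With BF16/FP16 inputs, a split reduction changes the summation order, so results typically match a single-pass reduction only within rounding tolerance.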
You will also see how we implement this idea as an AITER FlyDSL kernel family. FlyDSL keeps low-level ROCm™ software details such as MFMA selection, LDS layout, async copies, and synchronization explicit, while still generating shape-specialized variants for the model dimensions that appear in decode. In benchmark sweeps, this targeted decode optimization reaches a 1.64x average latency improvement over the fastest of hipBLASLt, AITER Triton, and AITER ASM on the K = 7168 decode grid[1], and a 1.49x average latency improvement on additional BF16 model-shape tests.
Why Does Decode Latency Matter for LLM Serving?
LLM serving has two broad phases:
Prefill, where the model processes the prompt.
Decode, where the model generates output tokens one step at a time.
Prefill often has a larger effective M because many prompt tokens can be processed together. Decode is different. Each step may only process a small number of active tokens, especially after batching, scheduling, tensor parallelism, and request-level dynamics are taken into account.
That makes decode performance important for user-facing latency:
Time to first token affects how quickly the system appears to respond.
Time per output token affects streaming smoothness.
Inter-token latency affects whether the interaction feels fluid.
Throughput under concurrency affects how many users can be served without hurting responsiveness.
Figure 1 illustrates this interactive decode serving setting and shows where these latency concerns appear in the user-facing path.
Figure 1: Interactive LLM decode serving.
For these workloads, shaving overhead from repeated decode GEMMs can matter at the model-serving level.
Why Do Small-M, Large-K GEMMs Underperform?
In large-model decode, GEMM often looks like:
C[M, N] = A[M, K] @ B[N, K]^T
Visually, the kernel still starts from the standard GEMM idea: compute a tile of C from a tile of A and a tile of B. The problem is that a small M produces too few output tiles, even though the K dimension can be very long.
Figure 2 shows this small-M, large-K bottleneck: the output grid is narrow, while the reduction dimension still contains substantial work.
Figure 2: Small-M, large-K GEMM bottleneck.
In the formula above, M is the number of active tokens in a decode step or micro-batch. For serving workloads, M is frequently small: 1, 2, 4, 8, 16, 32, sometimes up to 128 or 256. At the same time, N and K are model-hidden-size dimensions and can be in the thousands or tens of thousands.
That shape regime is awkward for general GEMM libraries. A conventional large-tile GEMM wants enough M x N work per block to keep all compute units busy. Decode GEMM often does not provide that naturally. The result is under-occupancy, poor wave utilization, and too much overhead relative to useful math.
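A quick back-of-the-envelope count shows the effect. The sketch below assumes a hypothetical 128 x 128 output tile, a hypothetical N = 4096, and a hypothetical GPU with 256 compute units; none of these values come from the benchmarks in this blog.

```python
import math

# All constants here are illustrative assumptions, not measured values.
TILE_M, TILE_N = 128, 128   # hypothetical CTA output tile
N = 4096                    # hypothetical output width
NUM_CUS = 256               # hypothetical compute-unit count


def output_tiles(m, n, tile_m=TILE_M, tile_n=TILE_N):
    """Number of independent output tiles in an M x N result."""
    return math.ceil(m / tile_m) * math.ceil(n / tile_n)


for m in (1, 8, 32, 128, 4096):
    tiles = output_tiles(m, N)
    # With one CTA per tile, at most `tiles` compute units can be busy.
    busy = min(tiles, NUM_CUS)
    print(f"M={m:5d}  tiles={tiles:5d}  busy CUs <= {busy}/{NUM_CUS}")
```

With these assumed values, every M up to 128 yields only 32 output tiles, so a one-tile-per-CTA launch could occupy at most 32 of the 256 assumed compute units, while a prefill-sized M = 4096 yields 1024 tiles.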
Common GEMM optimizations such as larger CTA tiles, better memory coalescing, LDS staging, MFMA-focused scheduling, and pipelining still matter. But they do not by themselves create enough independent work when the M x N output grid is small. This is why LDS-Pipelined Split-K combines multiple forms of K parallelism instead of relying on one optimization layer.
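One way to see how splitting K creates more independent work is to pick a split factor that raises the CTA count to the number of compute units. The heuristic below is an illustrative assumption, not the selection logic used by the AITER FlyDSL kernels; `pick_split_k`, the 256 compute units, the 32 output tiles, and the K block of 64 are hypothetical, while K = 7168 matches the decode grid mentioned above.

```python
import math


def pick_split_k(num_tiles, num_cus, k, block_k):
    """Smallest split factor giving at least one CTA per compute unit.

    Hypothetical heuristic for illustration only.
    """
    max_split = max(1, k // block_k)          # each split needs >= 1 K block
    wanted = math.ceil(num_cus / num_tiles)   # CTAs needed per output tile
    return max(1, min(wanted, max_split))


K = 7168        # reduction length from the decode grid
BLOCK_K = 64    # hypothetical K block per pipeline step
NUM_CUS = 256   # hypothetical compute-unit count
NUM_TILES = 32  # hypothetical output-tile count for a small-M shape

split_k = pick_split_k(NUM_TILES, NUM_CUS, K, BLOCK_K)
k_per_cta = math.ceil(K / split_k)   # K range owned by each CTA
total_ctas = NUM_TILES * split_k     # independent work units after the split

print(f"split_k={split_k}  k_per_cta={k_per_cta}  total_ctas={total_ctas}")
```

Under these assumptions the split factor is 8, each CTA reduces 896 elements of K, and the launch grows from 32 to 256 CTAs. The cost is an extra reduction of partial results, which is why the split factor is a tuning decision rather than something to maximize.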
Decode GEMM Shapes in Real Models
The motivation came directly from model shape traces, not from synthetic square GEMMs.
Across current LLMs, decode GEMM shapes repeatedly show the same pattern:
Model family
Typical decode GEMM...