Accelerating LLM Inference on AMD GPUs with Low-Latency GEMMs


June 29, 2026 by Yutao Xu, Xiaobing Zhang, Hattie Wu, Felix Li, Lingpeng Jin, Carlus Huang, Peng Sun, Barsoum Emad.



Large language model inference is becoming increasingly interactive. Users expect chatbots, coding assistants, agents, and real-time copilots to respond quickly, stream tokens smoothly, and stay responsive under concurrent load. In that setting, decode-time latency is not just a backend metric. It directly affects perceived quality.

In this blog, you will explore one small but important part of that inference path: decode-time GEMMs with small M, large N and K, BF16/FP16 inputs, optional bias, and shapes that repeat across real models. Conventional GEMM tiling can leave the GPU underutilized on these shapes, which makes them a useful target for decode-path optimization.

The main technique is LDS-Pipelined Split-K GEMM: the long K reduction is split across CTAs, further sliced across warp groups inside each CTA, and kept moving through a multi-stage LDS memory pipeline. On AMD GPUs, LDS means Local Data Share, the on-chip scratchpad memory used for fast cooperation inside a CTA.
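The two-level K split described above can be sketched in plain Python. This is only a numerical illustration of the reduction structure, not the FlyDSL kernel: the function names, the loop-based "CTAs" and "groups", and the integer inputs are all assumptions made for the example, and the LDS pipeline itself is not modeled.

```python
# Illustrative split-K reduction in plain Python (not the FlyDSL kernel).
# Computes C[m][n] = sum_k A[m][k] * B[n][k], with B stored as [N, K].
# All names here are hypothetical and exist only for this sketch.

def k_ranges(k_total, parts):
    """Split [0, k_total) into `parts` contiguous, near-equal ranges."""
    base, extra = divmod(k_total, parts)
    ranges, start = [], 0
    for i in range(parts):
        size = base + (1 if i < extra else 0)
        ranges.append((start, start + size))
        start += size
    return ranges

def partial_gemm(A, B, k_lo, k_hi):
    """Partial product over the K slice [k_lo, k_hi)."""
    return [[sum(a_row[k] * b_row[k] for k in range(k_lo, k_hi))
             for b_row in B] for a_row in A]

def split_k_gemm(A, B, num_ctas, groups_per_cta):
    """Two-level split: K across CTAs, then across groups inside each CTA."""
    M, N, K = len(A), len(B), len(A[0])
    C = [[0] * N for _ in range(M)]
    for cta_lo, cta_hi in k_ranges(K, num_ctas):
        # Each simulated CTA accumulates a partial tile from its group slices.
        cta_acc = [[0] * N for _ in range(M)]
        for g_lo, g_hi in k_ranges(cta_hi - cta_lo, groups_per_cta):
            part = partial_gemm(A, B, cta_lo + g_lo, cta_lo + g_hi)
            for m in range(M):
                for n in range(N):
                    cta_acc[m][n] += part[m][n]
        # Cross-CTA reduction into the final output.
        for m in range(M):
            for n in range(N):
                C[m][n] += cta_acc[m][n]
    return C
```

Because every K index is covered exactly once, the result matches a single pass over the full K range. On real hardware the partial sums are floating point, so the summation order can change the low-order bits of the result.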

You will also see how we implement this idea as an AITER FlyDSL kernel family. FlyDSL keeps low-level ROCm™ software details such as MFMA selection, LDS layout, async copies, and synchronization explicit, while still generating shape-specialized variants for the model dimensions that appear in decode. In benchmark sweeps, this targeted decode optimization reaches a 1.64x average latency improvement over the fastest of hipBLASLt, AITER Triton, and AITER ASM on the K = 7168 decode grid[1], and a 1.49x average latency improvement on additional BF16 model-shape tests.

Why Does Decode Latency Matter for LLM Serving?

LLM serving has two broad phases:

Prefill, where the model processes the prompt.

Decode, where the model generates output tokens one step at a time.

Prefill often has a larger effective M because many prompt tokens can be processed together. Decode is different. Each step may only process a small number of active tokens, especially after batching, scheduling, tensor parallelism, and request-level dynamics are taken into account.

That makes decode performance important for user-facing latency:

Time to first token affects how quickly the system appears to respond.

Time per output token affects streaming smoothness.

Inter-token latency affects whether the interaction feels fluid.

Throughput under concurrency affects how many users can be served without hurting responsiveness.
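The per-request metrics in this list can be computed from token arrival timestamps. The sketch below uses common working definitions, which are an assumption here since the text does not define them formally; the function name and the millisecond timestamps are hypothetical.

```python
# Minimal sketch of per-request latency metrics from token timestamps.
# Definitions are common conventions, assumed here for illustration.

def latency_metrics(request_ms, token_ms):
    """request_ms: time the request arrived; token_ms: arrival time of each output token."""
    if not token_ms:
        raise ValueError("at least one output token is required")
    # Time to first token: delay between the request and the first streamed token.
    ttft = token_ms[0] - request_ms
    # Inter-token latency: gap between each pair of consecutive tokens.
    gaps = [later - earlier for earlier, later in zip(token_ms, token_ms[1:])]
    # Time per output token: mean gap over the decode phase.
    tpot = sum(gaps) / len(gaps) if gaps else 0.0
    return {"ttft_ms": ttft, "inter_token_ms": gaps, "tpot_ms": tpot}

metrics = latency_metrics(0, [120, 150, 180, 240])
print(metrics)
```

A single slow decode step shows up as one large entry in `inter_token_ms` even when the mean time per output token looks acceptable, which is why the two are tracked separately.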

Figure 1 illustrates this interactive decode serving setting and shows where these latency concerns appear in the user-facing path.

Figure 1: Interactive LLM decode serving.

For these workloads, shaving overhead from repeated decode GEMMs can matter at the model-serving level.

Why Do Small-M, Large-K GEMMs Underperform?

In large-model decode, GEMM often looks like:

C[M, N] = A[M, K] @ B[N, K]^T

Visually, the kernel still starts from the standard GEMM idea: compute a tile of C from a tile of A and a tile of B. The problem is that a small M produces too few output tiles, even though the K dimension can be very long.

Figure 2 shows this small-M, large-K bottleneck: the output grid is narrow, while the reduction dimension still contains substantial work.

Figure 2: Small-M, large-K GEMM bottleneck.

Here, M is the number of active tokens in a decode step or micro-batch. For serving workloads, M is frequently small: 1, 2, 4, 8, 16, 32, sometimes up to 128 or 256. At the same time, N and K are model-hidden-size dimensions and can be in the thousands or tens of thousands.
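As a reference for the layout in the formula above, here is a plain-Python version of the same computation with B stored as [N, K]. The function name is hypothetical, and adding the bias per output column is an assumption, since the text only says the bias is optional.

```python
# Reference GEMM for C[M, N] = A[M, K] @ B[N, K]^T in plain Python.
# The per-column bias is an assumed convention, not taken from the source.

def gemm_nt(A, B, bias=None):
    """Return C with C[m][n] = sum_k A[m][k] * B[n][k], plus bias[n] if given."""
    M, K = len(A), len(A[0])
    N = len(B)
    assert all(len(row) == K for row in B), "A and B must share the K dimension"
    C = []
    for m in range(M):
        row = []
        for n in range(N):
            # Both operands are read along K, so each output is a row-row dot product.
            acc = sum(A[m][k] * B[n][k] for k in range(K))
            if bias is not None:
                acc += bias[n]
            row.append(acc)
        C.append(row)
    return C

# M = 1 mimics a single active token in a decode step.
A = [[1, 2, 3]]
B = [[1, 0, 0], [0, 1, 0], [1, 1, 1], [2, 2, 2]]
print(gemm_nt(A, B))
print(gemm_nt(A, B, bias=[10, 20, 30, 40]))
```

With M = 1 the output has a single row of N values, while each value still requires a full reduction over K.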

That shape regime is awkward for general GEMM libraries. A conventional large-tile GEMM wants enough M x N work per block to keep all compute units busy. Decode GEMM often does not provide that naturally. The result is under-occupancy, poor wave utilization, and too much overhead relative to useful math.
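A quick tile count makes the under-occupancy concrete. The tile size, the N value, and the compute-unit count below are hypothetical numbers chosen for illustration, not measurements or parameters from the kernels discussed here.

```python
import math

def output_tiles(M, N, tile_m, tile_n):
    """Number of tile_m x tile_n output tiles needed to cover an M x N result."""
    return math.ceil(M / tile_m) * math.ceil(N / tile_n)

def tile_row_utilization(M, tile_m):
    """Fraction of tile rows that hold real tokens; the rest is padding."""
    padded_rows = math.ceil(M / tile_m) * tile_m
    return M / padded_rows

# Hypothetical values for illustration only.
NUM_COMPUTE_UNITS = 256
TILE_M, TILE_N = 128, 128
N = 4096

for M in (1, 16, 128, 4096):
    tiles = output_tiles(M, N, TILE_M, TILE_N)
    busy_fraction = min(1.0, tiles / NUM_COMPUTE_UNITS)
    print(M, tiles, busy_fraction, tile_row_utilization(M, TILE_M))
```

Under these assumptions, M = 16 yields 32 output tiles, so at most one eighth of the compute units receive a tile, and only one eighth of the rows in each tile carry real tokens.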

Common GEMM optimizations such as larger CTA tiles, better memory coalescing, LDS staging, MFMA-focused scheduling, and pipelining still matter. But they do not by themselves create enough independent work when the M x N output grid is small. This is why LDS-Pipelined Split-K combines multiple forms of K parallelism instead of relying on one optimization layer.
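One way to see how a K split creates the missing independent work is to ask how many K partitions are needed to reach a target CTA count. The helper below is a hypothetical sizing heuristic written for this illustration; it is not the selection logic used by the AITER kernels, and every parameter except K = 7168 is an assumed value.

```python
import math

def min_split_k(M, N, K, tile_m, tile_n, target_ctas, k_step):
    """Smallest K split that reaches target_ctas, capped so each CTA keeps >= k_step of K."""
    tiles = math.ceil(M / tile_m) * math.ceil(N / tile_n)
    # Each output tile is computed by `split` CTAs, one per K partition.
    split = max(1, math.ceil(target_ctas / tiles))
    # Do not split so finely that a CTA has less than one k_step of work.
    max_split = max(1, K // k_step)
    split = min(split, max_split)
    k_per_cta = math.ceil(K / split)
    return tiles, split, k_per_cta

# Hypothetical shape and tiling; only K = 7168 comes from the benchmark grid in the text.
tiles, split, k_per_cta = min_split_k(
    M=16, N=4096, K=7168, tile_m=64, tile_n=128, target_ctas=256, k_step=64
)
print(tiles, split, tiles * split, k_per_cta)
```

In this example the 32 output tiles become 256 CTAs with a split of 8, and each CTA reduces 896 elements of K. The cost is an extra reduction across the partial results.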

Decode GEMM Shapes in Real Models

The motivation came directly from model shape traces, not from synthetic square GEMMs.

Across current LLMs, decode GEMM shapes repeatedly show the same pattern:

Model family

Typical decode GEMM...
