Modern GPU Programming For MLSys
Machine learning systems sit at the heart of modern AI workloads. In these systems, performance often comes down to the quality of a small number of GPU kernels. Attention kernels, LLM prefill and decode kernels, low-precision block-scaled GEMMs, fused MoE layers, and other large fused kernels all directly shape end-to-end speed in both training and serving.
To make these kernels fast, however, we need more than a list of optimization tricks. Modern GPUs are no longer simple variations of the same old design. Recent architectures introduce richer memory spaces (for example, Blackwell's tensor memory, TMEM), new data-access patterns, and increasingly specialized execution units (for example, the tcgen05 tensor cores). To program them well, we need both a clear mental model of the hardware and a practical understanding of how high-performance kernels are built. This book is about developing both.
The book follows a simple progression: first understand the GPU hardware, then learn the programming model we will use, and finally build state-of-the-art kernels step by step. Our main target is the NVIDIA Blackwell generation, and our main running examples are fast matrix multiplication (GEMM) and FlashAttention. Along the way, we will also study the core ingredients behind GPU optimization: data layout, asynchronous data movement, and asynchronous coordination.
The material grows out of the Machine Learning Systems course series at Carnegie Mellon University. To make the ideas easy to study and easy to run, this book uses the TIRx Python DSL to build real GPU kernel examples step by step. TIRx stays close to the hardware, so we can reason about low-level control while still learning through runnable code.
How This Book Is Organized
Part I, Understanding the GPU. This part introduces the overall organization of the GPU, general recipes for writing fast kernels, and key concepts such as data layout, asynchronous memory operations, and asynchronous coordination. It builds the hardware intuition that the rest of the book relies on.
Part II, TIRx Overview. This part introduces the key elements of TIRx, which serve as the foundation for the code examples throughout the book.
Part III, GEMM: Tiled to SOTA. A complete guide to optimizing a tiled GEMM, starting from a single-tile kernel and building up to a state-of-the-art (SOTA) implementation through TMA pipelining, persistent scheduling, warp specialization, and 2-CTA clusters.
Part IV, Flash Attention 4. A complete attention kernel built from the Part III techniques: two MMAs (QKᵀ, then PV) with a softmax between them, online-softmax rescaling, causal masking, and grouped-query attention (GQA).
Reference. The TIRx language reference, compiler internals, and a guide to debugging warp-specialized kernels.
Part I, Understanding the GPU
  GPU Execution Model
  What Makes a Kernel Fast
  Data Layout and Its Notation
  Tensor Core Operand Layouts Across GPU Generations
  Async Data Movement: TMA
  Tensor Cores: tcgen05
  Special Memory: TMEM
  Async Coordination: mbarriers
  Advanced: Cluster Launch Control
Part II, TIRx Overview
  Introduction to TIRx
  TIRx Layout API
Part III, GEMM: Tiled to SOTA
  Building a Tiled GEMM
    GEMM
    Optimization Path
    Step 1: Sequential Single-Tile GEMM
    Step 2: K-Loop Accumulation
    Step 3: Spatial Tiling (Multi-CTA)
    Exercises
  Pipelining GEMM with TMA
    Step 4: TMA Async Load
    Step 5: Software Pipeline (PIPE_DEPTH=2)
    Step 6: Persistent Kernel + Tile Scheduler
    Exercises
  Scaling GEMM with Warp Specialization and Clusters
    Step 7: Warp Specialization + Pipeline
    Step 8: 2-CTA Cluster
    Step 9: Multi-Consumer Warp Specialization
    End-to-End Result
    Exercises
Part IV, Flash Attention 4
  Flash Attention 4
    Algorithm Shape
    Tile-Primitive Graph
    Warp Roles and Scopes
    Reading the Fragments
    The Two MMA Phases
    TMEM Layout and Reuse
    How Barriers Connect the Roles
    Pipelining Structure
    Rescaling and Writeback
    Causal Masking
    GQA Support
    Tile Scheduling
    Compile and Verify
    Differences from GEMM
    Exercises
Reference
  Reference
  Debugging Warp-Specialized Kernels
  Compiler Internals
  TIRx Language Reference