When does fragmentation occur in the CUDA caching allocator?
PyTorch DevLog
Edward Yang (@ezyang) · June 1, 2026 · 12 min read
eager, cuda, memory

Disclosure. This post was drafted by Claude (Anthropic’s coding assistant) with editing from ezyang.
In an ideal world, users of CUDA memory in PyTorch programs should be able to abstract the allocator behavior as follows: there is a fixed amount of GPU memory, whenever you allocate, the available memory goes down, and whenever you free, it goes back up.

Unfortunately, the internal implementation of the CUDA caching allocator means that certain allocation patterns can give rise to fragmentation, where even though there is “technically” enough free space to store a requested allocation, the CUDA caching allocator is unable to actually serve the request.

There are many modern use cases where users wish to use as much of their GPUs’ memory as possible while still ensuring they do not OOM. Users are often penny-pinching allocations in this situation, and find it very surprising when PyTorch reserves more memory than they expect under the abstract model of the allocator.

This is especially common in LLM serving, where every megabyte of GPU memory that isn’t nailed down by model weights or CUDA graph buffers can be used for KV cache. Modern disaggregated serving involves recording a distinct CUDA graph for each batch size. It’s important for these graphs to share the same memory pool. But sharing a pool means the allocator’s internal bookkeeping needs to be correct before each recording. And the way the allocator manages memory (splitting and merging blocks) can go wrong in ways that depend on allocation order.

In this post, we’ll walk through some small laboratory examples where this fragmentation happens, and then demonstrate why expandable segments fix these examples.
It’s important to have a mental model for what exactly we mean by “fragmentation”, because some kinds of fragmentation can be solved with expandable segments (especially those related to recording CUDA graphs), while others cannot.

Segments, blocks, and splitting

The caching allocator organizes GPU memory in two levels. Segments are contiguous regions obtained from CUDA (via cudaMalloc or virtual memory mapping). Blocks are sub-regions within a segment that track individual allocations.

When a request comes in, the allocator finds a free block that’s large enough. If the block is bigger than needed, it splits the block: the front portion serves the allocation, and the back portion becomes a new free block. When a block is freed, the allocator tries to merge it with its immediate neighbors, but only if the neighbor is also free. Two free blocks separated by an allocated block cannot merge.

import gc, torch
MiB = 1024 * 1024
def alloc(n, mib, pool, dev):
    # Allocate n uint8 tensors of `mib` MiB each from the given pool.
    with torch.cuda.use_mem_pool(pool, dev):
        return [
            torch.empty(int(mib * MiB), dtype=torch.uint8, device=dev)
            for _ in range(n)
        ]

def free(ts):
    # Drop every tensor reference so the blocks return to the pool.
    ts.clear()

def layout(pool):
    # Print each segment in the pool with the size and state of its blocks.
    for s in torch.cuda.memory_snapshot(pool.id):
        blocks = " | ".join(f"{b['size']//MiB}M {b['state']}" for b in s["blocks"])
        print(f" seg {s['total_size']//MiB}M: [{blocks}]")

pool = torch.cuda.MemPool()
dev = torch.device("cuda:0")

t = alloc(1, 32, pool, dev)
layout(pool)  # one 32M block

free(t)

ts = alloc(2, 16, pool, dev)
layout(pool)  # 32M segment split into two 16M blocks

del ts[0]
layout(pool)  # first block inactive, second still active; can't merge

free(ts)
layout(pool)  # both free and adjacent; merged back to 32M
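The split and merge rules above can also be modeled without a GPU. The following is a minimal sketch, not the allocator's actual implementation: `ToySegment` and `Block` are hypothetical names, and the toy uses a simple first-fit search, which is a simplification assumed here for brevity. It shows that free space can add up to enough for a request while no single free block can serve it.

```python
from dataclasses import dataclass

MiB = 1024 * 1024


@dataclass
class Block:
    size: int          # size in bytes
    free: bool = True


class ToySegment:
    """Hypothetical model of one contiguous segment carved into blocks."""

    def __init__(self, size):
        self.blocks = [Block(size)]

    def alloc(self, size):
        # First fit (a simplification): take the first free block that fits.
        for i, b in enumerate(self.blocks):
            if b.free and b.size >= size:
                if b.size > size:
                    # Split: the front serves the request, the back stays free.
                    self.blocks.insert(i + 1, Block(b.size - size))
                    b.size = size
                b.free = False
                return b
        return None  # no single free block is large enough

    def free(self, block):
        block.free = True
        i = next(k for k, b in enumerate(self.blocks) if b is block)
        # Merge with the right neighbor, then the left one, if they are free.
        if i + 1 < len(self.blocks) and self.blocks[i + 1].free:
            block.size += self.blocks.pop(i + 1).size
        if i > 0 and self.blocks[i - 1].free:
            self.blocks[i - 1].size += self.blocks.pop(i).size

    def layout(self):
        return [(b.size // MiB, "free" if b.free else "used") for b in self.blocks]


seg = ToySegment(32 * MiB)
parts = [seg.alloc(8 * MiB) for _ in range(4)]
seg.free(parts[0])
seg.free(parts[2])
print(seg.layout())            # two 8M free blocks separated by used blocks
print(seg.alloc(16 * MiB))     # None: 16M is free in total, but not contiguous
```

Freeing `parts[1]` afterwards would let the three leading blocks merge into one 24 MiB free block, at which point the 16 MiB request succeeds.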
How segments are obtained depends on whether expandable segments are enabled. The behavior is quite different in each case.

Without expandable segments

Run scripts in this section with PYTORCH_CUDA_ALLOC_CONF=expandable_segments:False.

Without expandable segments, each cudaMalloc call creates a separate segment. For allocations between 1 MiB and 10 MiB, the allocator requests a 20 MiB segment regardless of the actual size. For allocations ≥ 10 MiB, the segment size is the request rounded up to the nearest multiple of 2 MiB.
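The sizing rule can be written down as a small function. This is a sketch under stated assumptions rather than the allocator's real code: `segment_size` is a hypothetical helper, the 2 MiB segment for requests of 1 MiB or less is an assumption not covered in this post, and any rounding the allocator applies to the request itself before sizing the segment is ignored.

```python
MiB = 1024 * 1024

SMALL_SIZE = 1 * MiB        # requests up to this size use small segments
SMALL_SEGMENT = 2 * MiB     # assumed segment size for small requests
LARGE_SEGMENT = 20 * MiB    # segment size for requests between 1 and 10 MiB
MIN_LARGE_ALLOC = 10 * MiB  # at or above this, size the segment to the request
ROUND_LARGE = 2 * MiB       # rounding granularity for large segments


def segment_size(request_bytes):
    """Hypothetical helper: segment size requested from CUDA for one request."""
    if request_bytes <= SMALL_SIZE:
        return SMALL_SEGMENT  # assumption, not stated in the text
    if request_bytes < MIN_LARGE_ALLOC:
        return LARGE_SEGMENT
    # Round up to the next multiple of 2 MiB (ceiling division).
    return -(-request_bytes // ROUND_LARGE) * ROUND_LARGE


for mib in (4, 10, 15, 32):
    print(f"{mib} MiB request -> {segment_size(mib * MiB) // MiB} MiB segment")
```

Under these assumptions, a 4 MiB request reserves a 20 MiB segment, leaving a 16 MiB free block behind it, while a 15 MiB request reserves 16 MiB.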
The key constraint: blocks in different segments can never merge. Each cudaMalloc is an independent allocation from CUDA’s perspective. A free 16 MiB block in one segment cannot combine with a free 16 MiB block in another segment to serve a 32 MiB request.

This is where allocation order matters. Let’s walk through two scenarios step by step.

Small then large (bad order):

import gc, torch
MiB = 1024 * 1024
def alloc(n, mib, pool, dev):
    # Allocate n uint8 tensors of `mib` MiB each from the given pool.
    with torch.cuda.use_mem_pool(pool, dev):
        return [
            torch.empty(int(mib * MiB), dtype=torch.uint8, device=dev)
            for _ in range(n)
        ]

def free(ts):
    # Drop every tensor reference so the blocks return to the pool.
    ts.clear()

def reserved(pool):
    # Total bytes reserved by the pool, summed over all of its segments.
    return sum(s["total_size"] for s in torch.cuda.memory_snapshot(pool.id))

def layout(pool):
    for s in torch.cuda.memory_snapshot(pool.id):
        blocks = " | ".join(f"{b['size']//MiB}M {b['state']}" for b in s["blocks"])
        print(f" seg...