Delayed Tensor Parallelism for Faster Transformer Inference
Modern LLM inference is increasingly shaped by latency-critical workloads: multi-step agentic workflows, real-time copilots, voice assistants, and reasoning systems that generate long chains of thought. In these settings, batch-size-one token generation speed, not aggregate throughput, is the metric that matters, and decoding latency is dominated by memory movement and synchronization overhead rather than raw compute.
In this post, we introduce Delayed Tensor Parallelism (DTP), a new architecture designed to hide communication behind computation and weight streaming. We show that DTP preserves the quality of standard tensor-parallel Transformer architectures while dramatically reducing exposed communication costs, enabling much faster inference on modern AMD and NVIDIA GPUs.
Introduction
LLMs are typically optimized for throughput: serving many users at once with large batch sizes to maximize hardware utilization. But not all applications live in that regime.<br>For applications such as voice assistants, real-time copilots, reasoning models, and agentic workflows, what matters to users is latency at batch size one. In this regime, the bottlenecks shift: performance is no longer compute-bound but dominated by weight streaming, kernel launch overheads, and memory movement.
A natural way to reduce these costs is to shard the model across multiple GPU devices. In practice, this is done with Tensor Parallelism [1] (TP), which splits the computation of attention and MLP layers across GPUs.
But TP is not a free lunch. It introduces communication overhead that can wipe out its benefits. This becomes especially painful when every other bottleneck, such as weight streaming continuity and kernel granularity, has already been heavily optimized [Monokernel].
A natural way to alleviate this communication overhead is to parallelize the model the way TP does while removing the communications entirely. However, we show that training from scratch with such a no-communication architecture variant heavily degrades model quality.<br>To claw back the quality gap introduced by removing communication, the Kog Team proposes Delayed Tensor Parallelism (DTP). DTP is an architectural variant of the base Transformer model that allows the TP scheme to overlap communication and computation. As a result, training an LLM with the DTP architecture gets the best of both worlds: it alleviates communication overhead while keeping model quality in the same ballpark.<br>We support these points experimentally: pretraining with our architecture variant instead of the usual Transformer blocks claws back quality relative to the version without communication. In fact, quality-wise, DTP stands very close to the vanilla Transformer blocks.<br>Furthermore, in our batch-size-one target setup, we compare DTP with state-of-the-art methods for communication overhead reduction and show that, when aiming at no communication overhead, DTP can significantly outperform them.<br>Building on these findings, we pretrained a 2B-parameter model with DTP. This model combines all of the Kog GPU team's optimizations with the DTP architecture and achieves unprecedented speed on AMD and NVIDIA datacenter GPUs.
Background on Tensor Parallelism
Tensor Parallelism [1] (TP) is usually the go-to technique when sharding a Transformer architecture across several GPU devices, especially when the model at hand fits into the aggregated memory of the GPUs on a single node.
In TP, weights are sharded across devices on a per-module basis, where a module refers to either an MLP or attention block. For a given module, each device performs its partial forward pass using its local parameter shard, then the partial outputs are aggregated via an all-reduce operation to recover the result equivalent to single-device execution.
More formally, for a Transformer with hidden size \(d\), let \(\mathbf{X}^{(n)} \in \mathbb{R}^{S \times d}\) denote an input of sequence length \(S\) to the \(n\)-th Transformer module \(\mathbf{M}^{(n)}(\cdot\,; \theta^{(n)})\) with parameters \(\theta^{(n)}\). Under TP with \(L\) devices, the parameter set is partitioned into disjoint shards \(\theta^{(n)}_l\), \(l = 1, \dots, L\), and the module output is computed as:
\[\mathbf{X}^{(n+1)} = \mathbf{X}^{(n)} + \sum_{l=1}^{L} \mathbf{o}_l^{(n)},\]
where each term \(\mathbf{o}_l^{(n)} = \mathbf{M}^{(n)}(\mathbf{X}^{(n)};\theta^{(n)}_l)\) is the local output of module \(n\) computed on device \(l\). Figure 1 summarizes the computational graph of a TP layer.
Figure 1: Tensor Parallelism computational graph
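To make the sum over devices concrete, here is a minimal sketch of the TP update for a single MLP module, written in plain Python with the devices simulated by a loop. It rests on several assumptions that are not from the post: a two-matrix ReLU MLP without biases, the common sharding layout where the first weight matrix is split by columns and the second by rows, and hypothetical helper names (`shard_mlp`, `tp_forward`). The running sum stands in for the all-reduce.

```python
import random

def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def relu(m):
    return [[max(0.0, v) for v in row] for row in m]

def add(a, b):
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]

def mlp(x, w_in, w_out):
    """Assumed module M: ReLU MLP without biases."""
    return matmul(relu(matmul(x, w_in)), w_out)

def shard_mlp(w_in, w_out, num_devices):
    """Split w_in by columns and w_out by rows into disjoint shards theta_l."""
    hidden = len(w_out)
    assert hidden % num_devices == 0, "hidden size must divide evenly"
    step = hidden // num_devices
    shards = []
    for l in range(num_devices):
        lo, hi = l * step, (l + 1) * step
        shards.append(([row[lo:hi] for row in w_in], w_out[lo:hi]))
    return shards

def tp_forward(x, shards):
    """Compute X + sum_l o_l, where o_l uses only the shard of device l."""
    total = [[0.0] * len(x[0]) for _ in x]
    for w_in_l, w_out_l in shards:
        o_l = mlp(x, w_in_l, w_out_l)  # local partial output on device l
        total = add(total, o_l)        # stands in for the all-reduce
    return add(x, total)               # residual connection

# Toy sizes (hypothetical): S=3 tokens, d=4, MLP hidden size 8, L=4 devices.
random.seed(0)
S, d, h, L = 3, 4, 8, 4
x = [[random.uniform(-1, 1) for _ in range(d)] for _ in range(S)]
w_in = [[random.uniform(-1, 1) for _ in range(h)] for _ in range(d)]
w_out = [[random.uniform(-1, 1) for _ in range(d)] for _ in range(h)]

shards = shard_mlp(w_in, w_out, L)
reference = add(x, mlp(x, w_in, w_out))  # single-device result
sharded = tp_forward(x, shards)
max_err = max(abs(a - b) for ra, rb in zip(reference, sharded) for a, b in zip(ra, rb))
print(f"max abs difference: {max_err:.2e}")
```

The sharded result matches the single-device result up to floating-point rounding because the activation is elementwise, so each device can apply it to its own slice of hidden units before the partial outputs are summed. On real hardware that sum is the all-reduce, which is the communication step TP pays for once per module.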
By enabling parallel streaming of parameter shards, TP can significantly reduce inference latency in memory-bound settings. However, it introduces a synchronization cost: each Transformer module requires an all-reduce operation to ensure equivalence with the non-partitioned model. This communication overhead can become a...