Why Distributed Training Is Hard: DTensor and the Costs of Abstraction


Runway News



May 18, 2026
by Wei Zhang, Research Engineering

DTensor makes distributed training correct by attaching placement metadata to every tensor. At scale it can also introduce costs that quietly erode throughput unless you design around them.

Why Distributed Training Is Hard

When you shard a tensor across a process group, every gradient that flows back through that shard has to match what you would have gotten on a single GPU. Getting this right manually means scattering collectives through the model, managing placement assumptions inside operators, and maintaining one-off codepaths for FSDP, tensor parallelism and pipeline parallelism. It is surprisingly easy to get wrong, and the bugs are almost always silent.

DTensor (PyTorch's Distributed Tensor) attempts to unify these concerns. Every DTensor carries a small piece of metadata describing its placement along each dimension of the device mesh: Replicate, Shard(dim), or Partial(sum). Operators then propagate placements automatically and insert the right collective operations when tensors need to move between layouts.
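The sketch below makes the three placements concrete. It is minimal and pure Python: it lays out a four-element value on two ranks under each placement and shows how the global value is recovered. The helpers (local_shard, full_from_shard and so on) are hypothetical. They model only the semantics, not the DTensor API, and assume a one-dimensional mesh with an even split along dimension 0.

```python
# Pure-Python model of DTensor placement semantics on a 1-D mesh.
# All helper names are hypothetical; this is not the DTensor API.
from typing import List


def local_shard(global_values: List[float], rank: int, world_size: int) -> List[float]:
    """Shard(0): each rank holds one contiguous slice (even split assumed)."""
    size = len(global_values) // world_size
    return global_values[rank * size:(rank + 1) * size]


def full_from_shard(locals_per_rank: List[List[float]]) -> List[float]:
    """Shard(0) -> global value: concatenate the slices (what an all-gather does)."""
    return [v for local in locals_per_rank for v in local]


def full_from_replicate(locals_per_rank: List[List[float]]) -> List[float]:
    """Replicate(): every rank holds the whole value, so any copy is the global value."""
    assert all(local == locals_per_rank[0] for local in locals_per_rank)
    return list(locals_per_rank[0])


def full_from_partial_sum(locals_per_rank: List[List[float]]) -> List[float]:
    """Partial(sum): each rank holds a full-shaped summand; the global value is
    their elementwise sum (what an all-reduce does)."""
    return [sum(vals) for vals in zip(*locals_per_rank)]


x = [1.0, 2.0, 3.0, 4.0]
shards = [local_shard(x, r, 2) for r in range(2)]
replicas = [list(x), list(x)]
partials = [[1.0, 2.0, 0.0, 0.0], [0.0, 0.0, 3.0, 4.0]]

print(shards)                           # [[1.0, 2.0], [3.0, 4.0]]
print(full_from_shard(shards))          # [1.0, 2.0, 3.0, 4.0]
print(full_from_replicate(replicas))    # [1.0, 2.0, 3.0, 4.0]
print(full_from_partial_sum(partials))  # [1.0, 2.0, 3.0, 4.0]
```

All three layouts describe the same global value. They differ only in what each rank physically holds, and that difference decides which collective is needed to move between them.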

In theory, that gives you cleaner abstractions and safer scaling. In practice it solves one class of problems and creates another.

Four Attempts to Parallelize a Three-Line Module

The cleanest way to motivate DTensor is to try the alternative. Consider this toy diffusion transformer modulation module. Each token belongs to one sample in the batch, and for every sample we have a conditioning embedding (timestep, class label, text features, …) that needs to modulate that sample's tokens. The module projects the conditioning into a per-channel scale and multiplies it into the token activations. This is a simplified version of the AdaLN modulation pattern (without shift and normalization):

```python
import torch


class Modulation(torch.nn.Module):
    def __init__(self, hidden_dim: int):
        super().__init__()
        # Learned projection: conditioning embedding -> per-channel scale.
        self.weight = torch.nn.Parameter(
            torch.randn(hidden_dim, hidden_dim, device=torch.cuda.current_device())
        )

    def forward(
        self,
        tokens: torch.Tensor,      # [num_tokens, hidden_dim]
        cond: torch.Tensor,        # [num_samples, hidden_dim]
        sample_ids: torch.Tensor,  # [num_tokens] -- which sample each token belongs to
    ) -> torch.Tensor:
        # 1. One scale vector per sample.
        per_sample_scale = torch.nn.functional.linear(cond, self.weight)
        # 2. Broadcast each sample's scale out to its tokens.
        per_token_scale = per_sample_scale.index_select(0, sample_ids)
        # 3. Modulate.
        return per_token_scale * tokens
```

The goal: shard tokens across a process group, compute locally, gather the result back, and produce four things that match the single-GPU baseline exactly: the forward result and the gradients on tokens, cond, and self.weight.

Getting the forward result right is easy. Getting the gradients right is not.
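Before distributing anything, it helps to pin down what "right" means for the four target quantities. The sketch below is a hand-written, pure-Python reference for the module's backward. Nested lists stand in for tensors, `grad_out` is assumed to be the gradient of the loss with respect to the output, and the function names are hypothetical.

The step worth noticing is that index_select's backward is a scatter-add. The gradient for a sample's scale sums over all of that sample's tokens. Once tokens are split across ranks, each rank's cond and weight gradients are therefore only partial sums.

```python
# Hand-written single-GPU reference for Modulation's backward, using nested
# lists instead of tensors so every step is explicit. Names are hypothetical.
from typing import List, Tuple

Matrix = List[List[float]]


def per_sample_scale(cond: Matrix, weight: Matrix) -> Matrix:
    # linear(cond, weight) computes cond @ weight.T
    return [[sum(c * w for c, w in zip(c_row, w_row)) for w_row in weight] for c_row in cond]


def modulation_backward(tokens: Matrix, cond: Matrix, weight: Matrix,
                        sample_ids: List[int], grad_out: Matrix) -> Tuple[Matrix, Matrix, Matrix]:
    hidden = len(weight)
    scale = per_sample_scale(cond, weight)
    # Gradient on tokens: each token is scaled by its own sample's scale vector.
    grad_tokens = [[scale[s][i] * g[i] for i in range(hidden)]
                   for s, g in zip(sample_ids, grad_out)]
    # index_select's backward is a scatter-add: every token of sample s
    # adds its contribution to that sample's scale gradient.
    grad_scale = [[0.0] * hidden for _ in cond]
    for tok, s, g in zip(tokens, sample_ids, grad_out):
        for i in range(hidden):
            grad_scale[s][i] += tok[i] * g[i]
    # linear's backward: grad_cond = grad_scale @ weight, grad_weight = grad_scale.T @ cond
    grad_cond = [[sum(gs[i] * weight[i][j] for i in range(hidden)) for j in range(hidden)]
                 for gs in grad_scale]
    grad_weight = [[sum(gs[i] * c_row[j] for gs, c_row in zip(grad_scale, cond))
                    for j in range(hidden)] for i in range(hidden)]
    return grad_tokens, grad_cond, grad_weight


tokens = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
cond = [[1.0, 1.0], [1.0, -1.0]]
weight = [[1.0, 2.0], [3.0, 4.0]]
sample_ids = [0, 0, 1, 1]
grad_out = [[1.0, 1.0] for _ in tokens]  # gradient of a plain sum() loss

g_tok, g_cond, g_w = modulation_backward(tokens, cond, weight, sample_ids, grad_out)
print(g_w)  # [[16.0, -8.0], [20.0, -8.0]]

# Split the tokens across two "ranks": each rank's weight gradient is only a
# partial sum over its own tokens, and the baseline is the sum over ranks.
partial_w = [modulation_backward(tokens[2 * r:2 * r + 2], cond, weight,
                                 sample_ids[2 * r:2 * r + 2], grad_out[2 * r:2 * r + 2])[2]
             for r in range(2)]
summed_w = [[a + b for a, b in zip(row0, row1)] for row0, row1 in zip(*partial_w)]
print(summed_w == g_w)  # True
```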

Attempt 1: torch.chunk and all_gather

The obvious first try: split tokens with torch.chunk, compute, all-gather, concatenate. The forward result is correct. But every gradient is wrong!

The problem is the backward of torch.chunk. Locally, it looks fine: it writes the incoming gradient into the matching slice of a full-size gradient for tokens and zero-fills the rest. With four tokens on two ranks, each rank sees this in tokens.grad after backward:

rank 0: tokens.grad = [g0, g1,  0,  0]
rank 1: tokens.grad = [ 0,  0, g2, g3]

From rank 0's perspective this is correct: rank 0 never touched the second half of tokens, so it has no gradient to contribute there. In the distributed setting, however, we need the full gradient on every rank, and chunk has no idea other ranks exist. Single-GPU ops are oblivious to other ranks, and that obliviousness is the source of every bug in this section.
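The rank-local picture above can be reproduced in a few lines of plain Python. This sketch models only the data movement, with a hypothetical `chunk_backward` helper and the values 1.0 to 4.0 standing in for g0 to g3. Each rank's zero-filled gradient turns out to be a partial contribution that matches the baseline only once summed across ranks.

```python
# Pure-Python model of torch.chunk's backward on each rank: the rank's slice of
# the upstream gradient is written into a zero-filled, full-length buffer.
# chunk_backward is a hypothetical helper, not a PyTorch function.
from typing import List


def chunk_backward(grad_slice: List[float], rank: int, world_size: int,
                   full_len: int) -> List[float]:
    size = full_len // world_size  # even split assumed
    grad = [0.0] * full_len
    grad[rank * size:(rank + 1) * size] = grad_slice
    return grad


# Single-GPU gradient on tokens, with g0..g3 stood in by 1.0..4.0.
full_grad = [1.0, 2.0, 3.0, 4.0]
per_rank = [chunk_backward(full_grad[r * 2:(r + 1) * 2], r, 2, 4) for r in range(2)]
print(per_rank)  # [[1.0, 2.0, 0.0, 0.0], [0.0, 0.0, 3.0, 4.0]]

# Neither rank matches the baseline on its own. The per-rank results are
# Partial(sum) contributions: summing them, as an all-reduce would, recovers it.
summed = [sum(vals) for vals in zip(*per_rank)]
print(summed == full_grad)  # True
```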

Attempt 2: a custom scatter

We replace torch.chunk with a custom autograd function whose backward all-gathers the partial gradients and concatenates them. Now tokens.grad is consistent across ranks.
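A minimal sketch of that custom scatter follows, modelled in plain Python with hypothetical names rather than as a real torch.autograd.Function. The forward keeps the local slice. The backward concatenates every rank's gradient slice, standing in for all_gather plus concat, so all ranks agree on tokens.grad. The sketch takes the incoming per-rank gradient slices as given.

```python
# Pure-Python model of a "scatter" whose backward all-gathers and concatenates
# the per-rank gradient slices. Names are hypothetical, not PyTorch APIs.
from typing import List


def scatter_forward(x: List[float], rank: int, world_size: int) -> List[float]:
    size = len(x) // world_size  # even split assumed
    return x[rank * size:(rank + 1) * size]


def scatter_backward(grad_slices_from_all_ranks: List[List[float]]) -> List[float]:
    # Stand-in for all_gather + concat: every rank receives every slice.
    return [g for grad_slice in grad_slices_from_all_ranks for g in grad_slice]


tokens = [10.0, 20.0, 30.0, 40.0]
local = [scatter_forward(tokens, r, 2) for r in range(2)]
grad_slices = [[1.0, 2.0], [3.0, 4.0]]  # each rank's gradient for its own slice
per_rank_grad = [scatter_backward(grad_slices) for _ in range(2)]
print(local)          # [[10.0, 20.0], [30.0, 40.0]]
print(per_rank_grad)  # [[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]]
```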

It is also exactly twice the baseline. With TP=2, all_gather's backward calls reduce_scatter: sum across ranks, then split. But the upstream gradient is identical on both ranks (the loss is computed on the gathered replicated output), so summing doubles it:

reduce:  [o0, o1, o2, o3] + [o0, o1, o2, o3] = [2*o0, 2*o1, 2*o2, 2*o3]
scatter: rank 0 gets [2*o0, 2*o1], rank 1 gets [2*o2, 2*o3]
correct: rank 0 gets [o0, o1],     rank 1 gets [o2, o3]
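The arithmetic above is easy to check in plain Python. The sketch below models the collectives as list operations with hypothetical helper names and does not call torch.distributed. With 1.0 to 4.0 standing in for o0 to o3, reduce_scatter over an identical upstream gradient returns exactly twice the correct slices, while a plain chunk returns the correct ones. For contrast, the same reduce_scatter gives the right answer when each rank holds a different summand of the gradient.

```python
# Pure-Python model of two candidate backwards for an all_gather that takes a
# sharded activation to a replicated one. Helper names are hypothetical.
from typing import List


def reduce_scatter(grads_per_rank: List[List[float]], rank: int) -> List[float]:
    """Sum the full-length gradients across ranks, then keep this rank's slice."""
    world_size = len(grads_per_rank)
    summed = [sum(vals) for vals in zip(*grads_per_rank)]
    size = len(summed) // world_size
    return summed[rank * size:(rank + 1) * size]


def chunk(grad: List[float], rank: int, world_size: int) -> List[float]:
    """Keep this rank's slice of a gradient that is already identical on every rank."""
    size = len(grad) // world_size
    return grad[rank * size:(rank + 1) * size]


upstream = [1.0, 2.0, 3.0, 4.0]  # o0..o3, identical on both ranks (Replicate)
replicated = [list(upstream), list(upstream)]

doubled = [reduce_scatter(replicated, r) for r in range(2)]
correct = [chunk(upstream, r, 2) for r in range(2)]
print(doubled)  # [[2.0, 4.0], [6.0, 8.0]]  -> TP_world_size * x
print(correct)  # [[1.0, 2.0], [3.0, 4.0]]

# Contrast: if each rank holds a different summand (Partial(sum)) whose total
# is o0..o3, reduce_scatter produces the correct slices.
partial = [[1.0, 0.0, 3.0, 1.0], [0.0, 2.0, 0.0, 3.0]]
from_partial = [reduce_scatter(partial, r) for r in range(2)]
print(from_partial == correct)  # True
```

The correct backward for the same forward collective depends on the placement of the incoming gradient, which is exactly the information a plain tensor does not carry.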

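Every value is TP_world_size * x instead of x. The root cause is a mismatch in assumptions. Our custom scatter's backward all-gathers and concatenates, so it treats the gradient arriving at it as sharded, with one distinct slice per rank. The all_gather in the forward goes the other way, from sharded to replicated. Because every rank computes the same loss on that replicated output, the gradient flowing back into it is replicated too. Its backward should therefore be a plain chunk (each rank takes its slice), not reduce_scatter. PyTorch ships reduce_scatter because that's...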

