Making Dr GRPO go brrr | Rubber Duck as a Service
02 Jun, 2026
I wrote a fused decode-attention kernel for an RL training loop, got it 2.2× faster than the SDPA path it replaces at the microbenchmark level, dropped it into HuggingFace's generate, and watched the decode step get nearly 3× slower. The kernel was doing exactly what the microbench said it would. The integration broke an auto-compile path that the baseline was quietly benefiting from. This post is how I got there, what the gap actually was, and what closing it would have cost.
The wider context: this is the write-up of a project to RL-train a small open-source model on GSM8K and write CuteDSL kernels for whichever paths dominate the runtime. The concrete setup is Qwen2.5-0.5B-Instruct, Dr. GRPO, and a single A10G. The post covers two things: building the training loop from scratch (and squeezing 4.8× out of the rollout phase before any kernel work), and then writing the kernel above for the path that still dominated. Most of what follows is about how those two results sit next to each other.
What is RL post-training, and why is it slow?
In RL post-training for LLMs, you have a policy (the model), a verifier (something that scores outputs), and a loop that pushes the policy to produce higher-scoring outputs. For a math task like GSM8K, the verifier is just a regex that pulls the final number out of the model's response and compares it to the ground truth.
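A minimal sketch of that kind of verifier, assuming two conventions the post does not spell out: the model's answer is the last number in its response, and the GSM8K reference answer ends with "#### <number>" (the format of the dataset's answer field). The names extract_last_number and gsm8k_reward are hypothetical, and the binary 1.0/0.0 reward is an illustrative choice.

```python
import re
from typing import Optional

# An optionally negative integer or decimal, allowing thousands separators such as "1,250".
_NUMBER = re.compile(r"-?\d[\d,]*(?:\.\d+)?")


def _to_float(text: str) -> float:
    return float(text.replace(",", ""))


def extract_last_number(response: str) -> Optional[float]:
    """Return the last number in the response, or None if it contains no number."""
    matches = _NUMBER.findall(response)
    return _to_float(matches[-1]) if matches else None


def gsm8k_reward(response: str, reference_answer: str) -> float:
    """1.0 if the response's final number matches the reference, else 0.0.

    Assumes a GSM8K-style reference whose last part is '#### <number>'.
    """
    gold = _to_float(reference_answer.split("####")[-1].strip())
    pred = extract_last_number(response)
    return 1.0 if pred is not None and pred == gold else 0.0
```

Taking the last number is lenient: a response that states the right answer and then keeps talking about other numbers is scored as wrong, which in practice pushes the model toward ending on its answer.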
Each training step has two phases.
Rollout. Sample a prompt. Generate G completions from the current policy. Score them. Compute advantages.
Update. For K inner epochs: forward pass through the policy, compute the GRPO loss against the rewards, backprop, optimizer step.
Rollout dominates wall time, and the reason is structural. Update is one big batched forward pass over (B*G, P+C) tokens (B prompts times G completions each, P prompt tokens plus C completion tokens), then a backward pass and an optimizer step: three large, well-batched chunks of GPU work per inner epoch. Rollout is model.generate, a sequential decode loop that runs one forward pass per generated token, each operating on (B*G, 1, hidden) plus a growing KV cache. Per-token compute is small, but you do it max_new_tokens times in series. Even with a KV cache and batching, you can't parallelize across the time dimension, because each token depends on the previous one.
So most of the time, the GPU is doing many small forwards instead of a few big ones. That's the shape of the problem and that's what kernel work has to address.
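To make the sequential dependency concrete, here is a toy decode loop with no real model. The function toy_generate is hypothetical: step_fn stands in for one forward pass and a plain list stands in for the KV cache. It only illustrates why rollout needs max_new_tokens serial passes while each update epoch needs one.

```python
from typing import Callable, List, Tuple


def toy_generate(
    prompt: List[int],
    step_fn: Callable[[List[int]], int],
    max_new_tokens: int,
) -> Tuple[List[int], int]:
    """Toy autoregressive decode loop.

    Returns (new_tokens, number_of_forward_passes). step_fn is a hypothetical
    stand-in for one model forward pass: it reads everything seen so far and
    returns the next token. A plain list stands in for the KV cache.
    """
    cache = list(prompt)          # grows by one entry per generated token
    new_tokens: List[int] = []
    forward_passes = 0
    for _ in range(max_new_tokens):
        token = step_fn(cache)    # the first call plays the role of prefill
        forward_passes += 1
        new_tokens.append(token)
        cache.append(token)       # pass t+1 cannot start until token t exists
    return new_tokens, forward_passes


# Rollout: max_new_tokens serial passes. Update: one batched pass per inner epoch.
tokens, passes = toy_generate([1, 2, 3], lambda cache: cache[-1] + 1, max_new_tokens=4)
print(tokens, passes)  # [4, 5, 6, 7] 4
```

Nothing inside the loop can be hoisted out: every iteration consumes the token the previous one produced, which is exactly the property that keeps per-step GPU work small.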
PPO
PPO is a policy-gradient method. You collect a rollout from the current policy, then run K epochs of mini-batch updates on that same rollout. Vanilla policy gradient is on-policy: collect a batch, do one update, throw the data away. PPO lets you reuse the rollout for K epochs (the whole reason it exists) by clipping the importance ratio so the policy can't drift too far from the one that generated the data.
The ratio is
$$r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)}$$
If $r_t = 1$, nothing changed. If $r_t > 1$, the new policy made the action more likely. The clipped objective is
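$$L^{\text{CLIP}}(\theta) = \hat{\mathbb{E}}_t\left[\min\left(r_t(\theta)\hat{A}_t,\ \operatorname{clip}\left(r_t(\theta), 1-\epsilon, 1+\epsilon\right)\hat{A}_t\right)\right]$$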
The min picks the more conservative (lower) of the two surrogates, so the policy can improve, but not by too much in one step.
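A per-token sketch of the clipped surrogate, assuming scalar log-probabilities from the new and old policy for a single token. The function name and the default eps = 0.2 are illustrative. The ratio is computed in log space, as is common in practice, since the policy produces log-probabilities.

```python
import math


def clip(x: float, lo: float, hi: float) -> float:
    return max(lo, min(x, hi))


def ppo_clipped_objective(
    logp_new: float, logp_old: float, advantage: float, eps: float = 0.2
) -> float:
    """Per-token PPO surrogate min(r*A, clip(r, 1-eps, 1+eps)*A), to be maximized."""
    ratio = math.exp(logp_new - logp_old)  # r_t = pi_theta / pi_theta_old
    unclipped = ratio * advantage
    clipped = clip(ratio, 1.0 - eps, 1.0 + eps) * advantage
    return min(unclipped, clipped)


# Positive advantage, ratio 1.5: the gain is capped at (1 + eps) * A = 1.2.
print(ppo_clipped_objective(math.log(1.5), 0.0, advantage=1.0))
```

Note the asymmetry: with a negative advantage, a ratio above 1 + ε is left unclipped, so the min keeps the larger penalty rather than capping it.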
Classical PPO also has a value network that estimates $V(s_t)$, with the advantage computed as $\hat{A}_t = R_t - V_\phi(s_t)$ (often via GAE).
GRPO
GRPO drops the value network. Instead of asking "is this output good?", it asks "is this output better than the others I sampled for the same prompt?"
The pipeline:
Sample G completions for the same prompt
Score them with a verifier
Compute the advantage as $A_i = (R_i - \mu)/\sigma$ inside the group
Apply the same PPO clipped objective
No critic at all
The whole machinery of estimating V and computing GAE goes away because the group itself acts as the baseline.
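The group-relative advantage is only a few lines. A minimal sketch, assuming scalar rewards for the G completions of one prompt. It uses the population standard deviation and adds a small eps to avoid dividing by zero when every completion gets the same reward; both are assumptions, and implementations differ on them. The function name is hypothetical.

```python
import statistics
from typing import List


def grpo_advantages(rewards: List[float], eps: float = 1e-6) -> List[float]:
    """Group-relative advantages A_i = (R_i - mu) / sigma for one prompt's G samples."""
    mu = statistics.fmean(rewards)
    sigma = statistics.pstdev(rewards)  # population std of the group's rewards
    return [(r - mu) / (sigma + eps) for r in rewards]


# One of four completions is correct: it gets a positive advantage, the rest negative.
print(grpo_advantages([1.0, 0.0, 0.0, 0.0]))
```

If every completion in a group gets the same reward, all advantages are zero and that prompt contributes no gradient, which is why group size and prompt difficulty matter so much in GRPO-style training.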
Dr. GRPO
GRPO has two bias problems.
Length bias. The original loss averages each response's token losses, i.e. scales them by $1/|o_i|$. When $A_i < 0$, longer responses get a weaker per-token penalty. The model learns "if I'm going to be wrong, be wrong at length." Output length drifts upward over training even when quality does not improve.
Difficulty bias. Dividing by σ inside a group amplifies gradients on prompts with small std (very easy or very hard ones). Medium-difficulty groups, where the most useful learning signal lives, get under-weighted.
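Both biases show up with a few numbers. This sketch assumes binary 0/1 rewards and the population std, and its helper names are hypothetical. It computes the factor that dividing by σ applies to a group's advantages, and the per-token weight $A_i/|o_i|$ that the per-response mean gives each token.

```python
import statistics


def sigma_scale(num_correct: int, group_size: int) -> float:
    """Factor by which GRPO's division by sigma scales a group's advantages.

    Binary 0/1 rewards; requires 0 < num_correct < group_size (otherwise sigma = 0).
    """
    rewards = [1.0] * num_correct + [0.0] * (group_size - num_correct)
    return 1.0 / statistics.pstdev(rewards)


def per_token_weight(advantage: float, length: int) -> float:
    """Weight each token gets under GRPO's per-response mean: A_i / |o_i|."""
    return advantage / length


for k in range(1, 8):
    print(f"{k}/8 correct -> sigma scale {sigma_scale(k, 8):.3f}")

print(per_token_weight(-1.0, 50), per_token_weight(-1.0, 500))
```

For G = 8, the σ scale is about 3.02 when 1 or 7 completions are correct and exactly 2.0 when 4 are, so dividing by σ boosts the near-all-wrong and near-all-right groups about 1.5× more than the evenly split one. On the length side, a wrong answer ($A_i = -1$) spread over 500 tokens gets a per-token weight of -0.002, ten times weaker than the -0.02 of a 50-token one.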
Dr. GRPO removes both denominators:
$$A_i = R_i - \mu$$
and uses token-sum aggregation instead of a per-response mean. The clipped objective stays the same.
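$$L^{\text{Dr.GRPO}}(\theta) = \frac{1}{G}\sum_{i=1}^{G}\sum_{t=1}^{|o_i|}\min\left(r_{i,t}(\theta)A_i,\ \operatorname{clip}\left(r_{i,t}(\theta), 1-\epsilon, 1+\epsilon\right)A_i\right)$$
Two deletions, no other changes.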
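The Dr. GRPO objective can be sketched over plain Python lists, next to the original GRPO aggregation for comparison. This is a sketch, not the post's training code. The per-token ratios $r_{i,t}$ are taken as given (in a real loop they come from exp(logprobs - old_logprobs)). Both functions take the same precomputed advantages so that only the aggregation differs. The function names are hypothetical, and the training loss would be the negative of the returned objective.

```python
from typing import List


def _clip(x: float, lo: float, hi: float) -> float:
    return max(lo, min(x, hi))


def _token_surrogate(ratio: float, adv: float, eps: float) -> float:
    return min(ratio * adv, _clip(ratio, 1.0 - eps, 1.0 + eps) * adv)


def dr_grpo_objective(ratios: List[List[float]], advantages: List[float], eps: float = 0.2) -> float:
    """(1/G) * sum_i sum_t min(r_it * A_i, clip(r_it) * A_i), to be maximized.

    ratios[i][t] is r_{i,t} for token t of completion i; advantages[i] = R_i - mean(R).
    """
    G = len(advantages)
    total = 0.0
    for r_i, a_i in zip(ratios, advantages):
        total += sum(_token_surrogate(r, a_i, eps) for r in r_i)  # token sum, no 1/|o_i|
    return total / G


def grpo_objective(ratios: List[List[float]], advantages: List[float], eps: float = 0.2) -> float:
    """Original GRPO aggregation for comparison: mean over each response's tokens, then over G."""
    G = len(advantages)
    per_response = [
        sum(_token_surrogate(r, a_i, eps) for r in r_i) / len(r_i)
        for r_i, a_i in zip(ratios, advantages)
    ]
    return sum(per_response) / G


# First inner epoch (all ratios 1): a short right answer and a longer wrong one.
ratios = [[1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]
advantages = [1.0, -1.0]
print(dr_grpo_objective(ratios, advantages), grpo_objective(ratios, advantages))  # -1.0 0.0
```

With all ratios at 1, a 2-token completion with advantage +1 and a 4-token completion with advantage -1 give -1.0 under the token sum but 0.0 under the per-response mean. Under Dr. GRPO the longer wrong answer's penalty is no longer diluted by its length.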
In pseudo-code, the whole thing looks like this:
for step in range(num_steps):
    # rollout
    prompts = sample(dataset, batch_size)
    completions = policy.generate(prompts, num_samples=G)
    rewards = verifier(completions)
    advantages = rewards - group_mean(rewards)  # no std division

    old_logprobs = policy.logprobs(completions).detach()
    ref_logprobs = ref_policy.logprobs(completions).detach()

    # update
    for _ in range(K):
        logprobs = policy.logprobs(completions)
        ratio = exp(logprobs -...