Radix Top-K: finding the top-k elements in an array without sorting


Radix Top-K – simons blog

Radix Top-K

10 Jun, 2026

Radix Top-K is an algorithm for finding the top-k elements in an array without sorting the full array.

For simplicity, assume the values are unsigned integers. The same idea can be extended to other representations.
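One common way to extend this to IEEE-754 float32 values is to map each float to a uint32 key whose unsigned order matches the float order. This is a standard trick from radix sorting, not something this note specifies. Below is a minimal plain-Python sketch. The function name `float32_to_ordered_uint32` is hypothetical, and NaNs are not handled.

```python
import struct


def float32_to_ordered_uint32(f: float) -> int:
    """Map a float32 to a uint32 key whose unsigned order matches the float order.

    Assumption: inputs are finite float32 values (NaNs are not handled).
    """
    # Reinterpret the float32 bit pattern as an unsigned 32-bit integer.
    bits = struct.unpack("<I", struct.pack("<f", f))[0]
    if bits & 0x80000000:
        # Negative: flip every bit so that more negative values get smaller keys.
        return bits ^ 0xFFFFFFFF
    # Non-negative: set the sign bit so these keys sort above all negatives.
    return bits | 0x80000000


values = [3.5, -1.0, 0.0, -7.25, 2.0]
keys = [float32_to_ordered_uint32(v) for v in values]
# Ordering by key reproduces the float order.
print([v for _, v in sorted(zip(keys, values))])  # [-7.25, -1.0, 0.0, 2.0, 3.5]
```

The resulting keys are plain unsigned integers with NUM_BITS = 32, so the procedure below applies to them unchanged.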

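Initial setup: Choose TOP_K and BITS_PER_ITER, and assume every element can be represented with NUM_BITS bits, where NUM_BITS is a multiple of BITS_PER_ITER so that the bits split evenly into NUM_BITS / BITS_PER_ITER rounds.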
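For example, the script below uses NUM_BITS = 4 and BITS_PER_ITER = 2. With those constants, each value is examined in two rounds of two bits, most significant chunk first. Here is a small sketch of the chunk extraction. The helper name `digit` is illustrative, and the divisibility check encodes the assumption that the bits split evenly into rounds.

```python
NUM_BITS = 4       # every value fits in NUM_BITS bits
BITS_PER_ITER = 2  # bits examined per round
NUM_BUCKETS = 2**BITS_PER_ITER

# Assumption: the bits split evenly into rounds.
assert NUM_BITS % BITS_PER_ITER == 0
NUM_ITERS = NUM_BITS // BITS_PER_ITER


def digit(value: int, iter_idx: int) -> int:
    """Return the BITS_PER_ITER-bit chunk used in round iter_idx (MSB first)."""
    shift = NUM_BITS - (iter_idx + 1) * BITS_PER_ITER
    return (value >> shift) & (NUM_BUCKETS - 1)


# 13 = 0b1101 splits into 0b11 (round 0) and 0b01 (round 1).
print([digit(13, i) for i in range(NUM_ITERS)])  # [3, 1]
# Round-0 buckets for the script's example array.
print([digit(v, 0) for v in [12, 4, 1, 8, 6, 5, 13, 0, 14]])  # [3, 1, 0, 2, 1, 1, 3, 0, 3]
```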

Then iteratively apply the following procedure:

Extract the next BITS_PER_ITER bits from all current candidates, starting from the most significant bits.

Count how many candidates fall into each bucket: 0, 1, ..., 2^{BITS_PER_ITER} - 1.

Perform an inclusive scan (prefix sum) over the bucket counts, so that inclusive_scan[i] is the number of candidates in buckets 0 through i.

Let K_remaining be TOP_K minus the number of elements already known to be in the top-k. Select the smallest bucket index i such that inclusive_scan[i] >= K_remaining. All elements in buckets j < i are guaranteed to be in the top-k. Elements in bucket i remain candidates for the next round. Elements in buckets j > i are discarded. This works because buckets 0 through i - 1 hold fewer than K_remaining elements in total, while buckets 0 through i hold at least K_remaining.

Repeat with the new candidates as input, updating K_remaining by subtracting the number of elements already guaranteed to be in the top-k.
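The four steps of a single round can be sketched in plain Python as follows. This is a list-based sketch, not the author's script. The name `radix_round` and its default parameters are hypothetical. The function returns the values guaranteed to be in the top-k, together with the candidates carried into the next round.

```python
from itertools import accumulate


def radix_round(candidates, k_remaining, iter_idx, num_bits=4, bits_per_iter=2):
    """One round of radix top-k (smallest values).

    Returns (guaranteed, next_candidates).
    """
    num_buckets = 2**bits_per_iter
    shift = num_bits - (iter_idx + 1) * bits_per_iter

    # Step 1: extract the current chunk of bits from every candidate.
    digits = [(v >> shift) & (num_buckets - 1) for v in candidates]

    # Step 2: count candidates per bucket.
    hist = [0] * num_buckets
    for d in digits:
        hist[d] += 1

    # Step 3: inclusive scan, scan[i] = number of candidates in buckets 0..i.
    scan = list(accumulate(hist))

    # Step 4: first bucket whose inclusive count reaches k_remaining
    # (fall back to the last bucket if there are too few candidates).
    border = next(
        (i for i, s in enumerate(scan) if s >= k_remaining), num_buckets - 1
    )

    guaranteed = [v for v, d in zip(candidates, digits) if d < border]
    next_candidates = [v for v, d in zip(candidates, digits) if d == border]
    return guaranteed, next_candidates


x = [12, 4, 1, 8, 6, 5, 13, 0, 14]
print(radix_round(x, k_remaining=4, iter_idx=0))          # ([1, 0], [4, 6, 5])
print(radix_round([4, 6, 5], k_remaining=2, iter_idx=1))  # ([4], [5])
```

On the example array from the script, the first round guarantees 1 and 0 and keeps 4, 6 and 5 as candidates. The second round runs on those three with K_remaining = 2, guarantees 4, and leaves 5 as the only candidate.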

After all bit chunks have been processed, if more candidates remain than open top-k slots, keep only as many as needed. This can happen when multiple values are tied at the boundary.
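Putting the rounds together with the final truncation gives a complete, if slow, reference version. This sketch uses the hypothetical name `radix_topk_smallest` and plain Python lists rather than tensors. It assumes non-negative integers below 2**num_bits. Under that assumption, any candidates that survive every round share all num_bits bits and are therefore equal, so keeping the first few of them is safe.

```python
import random
from itertools import accumulate


def radix_topk_smallest(values, top_k, num_bits=8, bits_per_iter=2):
    """Return the top_k smallest values (in no particular order).

    Assumes non-negative integers below 2**num_bits and top_k <= len(values).
    """
    assert num_bits % bits_per_iter == 0
    num_buckets = 2**bits_per_iter
    guaranteed, candidates = [], list(values)

    for iter_idx in range(num_bits // bits_per_iter):
        k_remaining = top_k - len(guaranteed)
        if k_remaining == 0:
            break  # every slot is already filled
        shift = num_bits - (iter_idx + 1) * bits_per_iter
        digits = [(v >> shift) & (num_buckets - 1) for v in candidates]
        hist = [0] * num_buckets
        for d in digits:
            hist[d] += 1
        scan = list(accumulate(hist))
        border = next(
            (i for i, s in enumerate(scan) if s >= k_remaining), num_buckets - 1
        )
        guaranteed += [v for v, d in zip(candidates, digits) if d < border]
        candidates = [v for v, d in zip(candidates, digits) if d == border]

    # If all rounds ran, surviving candidates agree on every bit (they are tied),
    # so any of them can fill the open slots; after an early exit none are open.
    return guaranteed + candidates[: top_k - len(guaranteed)]


# Three-way tie at the boundary: only two of the 3s are kept.
print(radix_topk_smallest([7, 3, 3, 3, 9], top_k=2))  # [3, 3]

# Cross-check against a full sort on random inputs.
rng = random.Random(0)
for _ in range(200):
    vals = [rng.randrange(256) for _ in range(rng.randint(1, 50))]
    k = rng.randint(1, len(vals))
    assert sorted(radix_topk_smallest(vals, k)) == sorted(vals)[:k]
```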

As written, this finds the TOP_K smallest values. For TOP_K largest values, reverse the bucket order.
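Here is a sketch of the largest-k variant, again in plain Python with a hypothetical function name. The inclusive scan runs over the buckets from highest to lowest. Buckets above the border are kept, and buckets below it are discarded.

```python
from itertools import accumulate


def radix_topk_largest(values, top_k, num_bits=4, bits_per_iter=2):
    """Return the top_k largest values (in no particular order).

    Assumes non-negative integers below 2**num_bits and top_k <= len(values).
    """
    num_buckets = 2**bits_per_iter
    guaranteed, candidates = [], list(values)

    for iter_idx in range(num_bits // bits_per_iter):
        k_remaining = top_k - len(guaranteed)
        shift = num_bits - (iter_idx + 1) * bits_per_iter
        digits = [(v >> shift) & (num_buckets - 1) for v in candidates]
        hist = [0] * num_buckets
        for d in digits:
            hist[d] += 1
        # Inclusive scan in reversed bucket order: highest bucket first.
        scan_desc = list(accumulate(reversed(hist)))
        pos = next(
            (j for j, s in enumerate(scan_desc) if s >= k_remaining), num_buckets - 1
        )
        border = num_buckets - 1 - pos
        # Buckets above the border are guaranteed; the border bucket stays a
        # candidate; buckets below it are discarded.
        guaranteed += [v for v, d in zip(candidates, digits) if d > border]
        candidates = [v for v, d in zip(candidates, digits) if d == border]

    return guaranteed + candidates[: top_k - len(guaranteed)]


x = [12, 4, 1, 8, 6, 5, 13, 0, 14]
print(radix_topk_largest(x, top_k=4))  # [12, 13, 14, 8]
```

If you already have a smallest-k routine, there is an equivalent route. Run it on the complemented keys (2**NUM_BITS - 1) - v, which reverses the order of every bucket at once.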

[Figure: simple illustration of the procedure]

Minimal script to reproduce:

```python
import torch

TOP_K = 4
NUM_BITS = 4
BITS_PER_ITER = 2
NUM_BUCKETS = 2**BITS_PER_ITER


def iteration(
    iter_idx,
    current_topk,
    current_topk_idxs,
    next_candidates,
    next_candidate_idxs,
):
    print(f"\nITER = {iter_idx}")

    # Shift so the current BITS_PER_ITER-bit chunk (MSB first) lands in the low bits.
    num_shift_right = NUM_BITS - (iter_idx + 1) * BITS_PER_ITER

    shifted = torch.bitwise_right_shift(next_candidates, num_shift_right)
    shifted = torch.bitwise_and(shifted, NUM_BUCKETS - 1)
    print(f"{shifted=}")

    # Count how many candidates fall into each bucket.
    hist = torch.bincount(shifted, minlength=NUM_BUCKETS)
    print(f"{hist=}")

    # Inclusive scan over the bucket counts.
    inclusive_scan = torch.cumsum(hist, dim=0)
    print(f"{inclusive_scan=}")

    # First bucket whose inclusive count reaches the number of open slots.
    num_needed = TOP_K - current_topk.numel()
    mask = inclusive_scan >= num_needed
    border = mask.float().argmax()
    print(f"{border=}")

    # Buckets below the border are in the top-k; the border bucket stays a candidate.
    idxs_current_topk = shifted < border
    idxs_next_candidates = shifted == border

    current_topk = torch.cat([current_topk, next_candidates[idxs_current_topk]])
    current_topk_idxs = torch.cat(
        [current_topk_idxs, next_candidate_idxs[idxs_current_topk]]
    )

    next_candidates = next_candidates[idxs_next_candidates]
    next_candidate_idxs = next_candidate_idxs[idxs_next_candidates]

    print(f"{current_topk=}")
    print(f"{current_topk_idxs=}")
    print(f"{next_candidates=}")
    print(f"{next_candidate_idxs=}")

    return current_topk, current_topk_idxs, next_candidates, next_candidate_idxs


if __name__ == "__main__":
    # Fall back to the CPU when no GPU is available.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.tensor([12, 4, 1, 8, 6, 5, 13, 0, 14], device=device)
    print(f"{x=}")

    current_topk = torch.empty(0, dtype=x.dtype, device=x.device)
    current_topk_idxs = torch.empty(0, dtype=torch.long, device=x.device)

    next_candidates = x
    next_candidate_idxs = torch.arange(x.numel(), device=x.device)

    num_iters = NUM_BITS // BITS_PER_ITER

    for iter_idx in range(num_iters):
        (
            current_topk,
            current_topk_idxs,
            next_candidates,
            next_candidate_idxs,
        ) = iteration(
            iter_idx,
            current_topk,
            current_topk_idxs,
            next_candidates,
            next_candidate_idxs,
        )

    # Fill the remaining slots from the boundary bucket (handles ties).
    num_remaining = TOP_K - current_topk.numel()

    final_topk = torch.cat([current_topk, next_candidates[:num_remaining]])
    final_topk_idxs = torch.cat(
        [current_topk_idxs, next_candidate_idxs[:num_remaining]]
    )

    print(f"\n{final_topk=}")
    print(f"{final_topk_idxs=}")
```

I hope this small note helps others to learn about Radix TopK Select.
