The Transformer: The Life of a Token

Inside the Transformer: The Life of a Token - Aleksa Gordić

In this post, I'll do a deep dive into the internals of a modern dense transformer [1]. I'll focus exclusively on the forward pass on a single GPU, as if we were about to perform a training step, while ignoring the backward pass and distributed systems details (in practice, large Transformers are sharded across multiple devices during both training and inference).<br>As a running example, I'll use the exact architecture of Rnj 1.5 - a model I worked on with my team at Ashish Vaswani's AI Lab (Essential AI Labs).<br>💡The team behind Rnj-1.5:<br>Rnj 1.5 could not have happened without an amazing group of people (sorted alphabetically):<br>Code pod: Adarsh Chaluvaraju, Devaansh Gupta, Yash Jain, Somanshu Singla, Saurabh Srivastava (tech lead), Anil Thomas<br>STEM pod: Aleksa Gordić (tech lead), Michael Pust, Tim Romanski, Ali Shehper, Kurt Smith (tech lead), Ameya Velingker<br>Infra pod: Mike Callahan, Philip Monk (tech lead), Khoi Nguyen (tech lead), Alok Tripathy, Yash Vanjani<br>Org: Divya Mansingka, Mohit Parmar, Peter Rushton<br>Research and Engineering Roadmap: Ashish Vaswani

We announced it this week, with weights released on Hugging Face.<br>It's a long-context follow-up to Rnj 1.0 [2] that extends the context window from 32k to 160k, scoring 79% on RULER at a 128k context window. This release also offers stronger coding abilities across a wider range of harnesses. See our model card for more details.<br>This post is structured into seven parts:<br>Transformer forward pass: high-level flow of a token<br>RMSNorm: the normalization layer<br>GeGLU MLP: GELU-gated feedforward block<br>MHA: multi-head self-attention<br>YaRN: positional embeddings for long context<br>Core Attention: global + block local<br>Transformer math: FLOPs/token, cluster sizing, and more<br>In a follow-up post, I'll dive into conditional computation, focusing on sparse transformers (MoE).<br>Transformer forward pass<br>As a running example, assume we sample 2 "documents" from a dataset, with:<br>batch size = 1<br>sequence length = 16<br>document packing enabled

We'll trace how a token flows through the transformer and, along the way, unpack each component.<br>Let's start. Spend some time analyzing the following:<br>Figure 1: Tokenization stage

We tokenize the documents into sequences of integers, then pack the two documents into a single sequence.<br>For the scope of this blog post, the tokenizer is a black-box component that takes in text and maps it to a sequence of tokens, each represented by an integer ID. In practice, tokenizers are “trained” on a separate corpus of text using algorithms such as BPE, which learn a vocabulary by repeatedly merging frequent character or byte sequences. Good tokenizer design has several desirable properties; for example, representing digits as individual tokens can help with numerical reasoning.
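To make "repeatedly merging frequent character or byte sequences" concrete, here is a minimal character-level BPE training loop. It is a toy sketch, not the tokenizer used by Rnj 1.5: the function names (`most_frequent_pair`, `merge_pair`, `train_bpe`) and the tiny word list are made up for illustration, and real tokenizers typically operate on bytes, apply pre-tokenization rules, and weight pairs by corpus frequency.

```python
from collections import Counter


def most_frequent_pair(seqs):
    """Return the most common adjacent symbol pair across all sequences."""
    counts = Counter()
    for seq in seqs:
        counts.update(zip(seq, seq[1:]))
    return counts.most_common(1)[0][0] if counts else None


def merge_pair(seq, pair, new_symbol):
    """Replace every non-overlapping occurrence of `pair` with `new_symbol`."""
    out, i = [], 0
    while i < len(seq):
        if i + 1 < len(seq) and (seq[i], seq[i + 1]) == pair:
            out.append(new_symbol)
            i += 2
        else:
            out.append(seq[i])
            i += 1
    return out


def train_bpe(words, num_merges):
    """Learn up to `num_merges` merge rules, starting from single characters."""
    seqs = [list(w) for w in words]
    merges = []
    for _ in range(num_merges):
        pair = most_frequent_pair(seqs)
        if pair is None:  # nothing left to merge
            break
        merges.append(pair)
        seqs = [merge_pair(s, pair, pair[0] + pair[1]) for s in seqs]
    return merges, seqs


# Hypothetical toy corpus: after two merges, "low" becomes a single symbol.
merges, seqs = train_bpe(["low", "lower", "lowest", "low"], num_merges=2)
print(merges)
print(seqs)
```

Each learned merge adds one entry to the vocabulary; mapping the final symbols to integer IDs gives the token sequences used in the rest of the post.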

Alongside the tokens, we construct two supporting structures:<br>input positions - used by the positional embedding module (YaRN)<br>segmentation mask - used in attention for masking<br>This is the preprocessing stage.<br>📝Side note:<br>For efficiency reasons, the data is chunked ahead of time, before training starts, and the data loader feeds these preprocessed structures directly into the training loop. At that point, we never deal with raw strings. The (Spark) data pipelines and the data loader could easily be separate blog posts.
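The preprocessing stage can be sketched in a few lines of plain Python. Everything below is an illustrative assumption rather than the actual data pipeline: the helper names (`pack_documents`, `build_attention_mask`), the token IDs, the 9 + 7 split of the 16-token sequence, restarting positions at 0 for each document, and using segment ID 0 for padding. The mask shown combines causal masking with a same-segment check, which is one common way a segmentation mask is used in attention.

```python
def pack_documents(docs, seq_len, pad_id=0):
    """Pack tokenized documents into one sequence with positions and segment IDs."""
    tokens, positions, segment_ids = [], [], []
    for seg, doc in enumerate(docs, start=1):
        tokens.extend(doc)
        positions.extend(range(len(doc)))      # assumption: positions restart per document
        segment_ids.extend([seg] * len(doc))   # 1 for the first document, 2 for the second, ...
    if len(tokens) > seq_len:
        raise ValueError("documents do not fit into seq_len")
    pad = seq_len - len(tokens)
    tokens += [pad_id] * pad
    positions += [0] * pad
    segment_ids += [0] * pad                   # assumption: segment 0 marks padding
    return tokens, positions, segment_ids


def build_attention_mask(segment_ids):
    """mask[i][j] is True when query i may attend to key j."""
    n = len(segment_ids)
    return [
        [
            j <= i                                  # causal: no looking ahead
            and segment_ids[i] == segment_ids[j]    # stay within the same document
            and segment_ids[i] != 0                 # padding attends to nothing
            for j in range(n)
        ]
        for i in range(n)
    ]


# Hypothetical token IDs: a 9-token and a 7-token document packed into 16 slots.
doc_a = [101, 102, 103, 104, 105, 106, 107, 108, 109]
doc_b = [201, 202, 203, 204, 205, 206, 207]
tokens, positions, segment_ids = pack_documents([doc_a, doc_b], seq_len=16)
mask = build_attention_mask(segment_ids)
print(positions)
print(segment_ids)
```

With this mask, the first token of the second document (index 9) can attend only to itself, even though it sits right after the first document in the packed sequence.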

Next, we use the input tokens to index into the embedding table.<br>You can think of the embedding table as the vocabulary of the LLM.<br>This indexing operation converts our sequence of integers into a sequence of 16 4096-dim bf16 vectors:<br>Figure 2: Embedding stage
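The lookup itself is just row indexing. The sketch below uses plain Python lists and deliberately tiny, made-up sizes (a 32-entry vocabulary and 8-dimensional vectors instead of 4096); it does not model bf16, and `make_embedding_table` and `embed` are hypothetical helper names, not part of any library.

```python
import random


def make_embedding_table(vocab_size, d_model, seed=0):
    """Toy embedding table: one d_model-dimensional row per token ID."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 0.02) for _ in range(d_model)] for _ in range(vocab_size)]


def embed(tokens, table):
    """Map each integer token ID to its row in the table."""
    return [table[t] for t in tokens]


# Toy sizes for readability; the post's model uses 4096-dim vectors.
VOCAB_SIZE, D_MODEL = 32, 8
table = make_embedding_table(VOCAB_SIZE, D_MODEL)

# Hypothetical packed sequence of 16 token IDs.
tokens = [3, 7, 1, 12, 30, 5, 5, 9, 2, 4, 8, 16, 23, 0, 11, 6]
hidden = embed(tokens, table)
print(len(hidden), len(hidden[0]))  # sequence length, vector dimension
```

Repeated token IDs (the two 5s above) map to the same row, so any difference between them later in the network has to come from position and context.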

📝Side note:<br>Special tokens don't naturally appear during tokenization - no text maps to token IDs >= 128,000. They're injected during training (and later used at inference) to improve performance (e.g. FIM, repo packing, etc.) or to enforce specific behaviors (e.g. end of generation / turn, tool calls).<br>Let's dig into FIM [3] (fill-in-the-middle) special tokens.<br>During (pre)training, we take a document, split it into prefix, middle (infix), and suffix, and construct a sequence of the form: <FIM_PRE> prefix <FIM_SUF> suffix <FIM_MID> middle. The model is trained to predict the middle given the prefix and suffix. This capability can then be leveraged at inference time.<br>For example, imagine using Rnj 1.5 as an autocomplete model in your favorite IDE. Your cursor naturally splits the code into a prefix and suffix, with the middle missing. By inserting FIM tokens and ending with <FIM_MID>, you prompt the model to generate a completion for the gap. These tokens help communicate intent to...
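The FIM construction described in the side note can be sketched as follows. This is a simplified illustration working on raw strings: the sentinel spellings, the uniform random split, and the helper names (`split_document`, `fim_training_example`, `fim_inference_prompt`) are assumptions made for this example. In a real pipeline the sentinels are reserved token IDs inserted after tokenization, and the exact splitting strategy may differ.

```python
import random

# Placeholder sentinel strings; real models use reserved special-token IDs.
FIM_PRE, FIM_SUF, FIM_MID = "<FIM_PRE>", "<FIM_SUF>", "<FIM_MID>"


def split_document(doc, rng):
    """Cut the document at two random points into prefix, middle, suffix."""
    i, j = sorted(rng.randint(0, len(doc)) for _ in range(2))
    return doc[:i], doc[i:j], doc[j:]


def fim_training_example(doc, rng):
    """Reorder a document as: PRE prefix SUF suffix MID middle."""
    prefix, middle, suffix = split_document(doc, rng)
    return f"{FIM_PRE}{prefix}{FIM_SUF}{suffix}{FIM_MID}{middle}"


def fim_inference_prompt(prefix, suffix):
    """Prompt ending in MID, so the model generates the missing middle."""
    return f"{FIM_PRE}{prefix}{FIM_SUF}{suffix}{FIM_MID}"


rng = random.Random(0)
doc = "def add(a, b):\n    return a + b\n"
example = fim_training_example(doc, rng)

# IDE-style usage: the cursor sits between the prefix and the suffix.
prompt = fim_inference_prompt("def add(a, b):\n    return ", "\n")
print(repr(example))
print(repr(prompt))
```

Because the middle comes last, ordinary left-to-right next-token prediction is enough to train infilling; no architectural change is needed.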
