Flax debugging: making a hash of things :: Giles' blog
Posted on 17 June 2026 in AI, TIL, JAX, Python
I was debugging an issue with a JAX/Flax NNX training loop the other day, and found a neat little trick to help debug it. Specifically, I wanted to see if the issue was with my model, my loss function, my optimiser settings, or the "plumbing" of the training loop itself -- were gradients actually coming through and being applied to the parameters?
I could print out the loss and the gradients, but printing out the parameters to see if they were changing was unhelpful -- any given update might only change a small number of parameters, or might change them by such a small amount that I wouldn't notice -- especially given that the model had 77 million of them!
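To illustrate the problem, here is a minimal sketch. It is not code from the training loop in question: the `params` array and the size of the update are made up for the example. It shows that a small change to one element of a large array does not show up in the printed output, because large arrays are printed in summarised form.

```python
import jax.numpy as jnp

# Hypothetical stand-in for a model's parameters: one large flat array.
params = jnp.zeros(1_000_000, dtype=jnp.float32)

# Simulate an update that nudges a single parameter by a small amount.
updated = params.at[500_000].add(1e-3)

# Large arrays are printed in summarised form (only the first and last
# few elements are shown), so the two printouts are identical.
same_printout = str(params) == str(updated)

# An exact element-wise comparison shows that the arrays do differ.
actually_equal = bool(jnp.array_equal(params, updated))

print("printouts match:", same_printout)
print("arrays equal:", actually_equal)
```

With tens of millions of parameters spread over many arrays, eyeballing the printed values is even less likely to reveal whether an update happened.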
Let's take a look.
The world's worst LLM
I am building an LLM from scratch in JAX and Flax NNX, and at this stage I'm trying to get the training loop right. As a simple test, I've just implemented the "shell" of the LLM -- the token embeddings on the input side, and the final linear layer for an output head, wired directly together. My plan was to train that so that given a sequence, instead of predicting next tokens for each position, it would "predict" the sequence itself -- that is, I might train it with the input
The fat cat sat on the mat
...and the target
The fat cat sat on the mat
...rather than the normal setup for an LLM, where you feed it
The fat cat sat on the
...and give it targets of
fat cat sat on the mat
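The two setups above can be sketched in a few lines of plain Python. The whitespace split is a hypothetical stand-in for a real tokeniser, used only so that the example is self-contained.

```python
# Hypothetical whitespace "tokeniser", purely for illustration.
tokens = "The fat cat sat on the mat".split()

# Normal next-token setup: the targets are the inputs shifted left by one,
# so each position is trained to predict the token that follows it.
next_token_inputs = tokens[:-1]
next_token_targets = tokens[1:]

# The identity setup used for this test: the targets are the inputs
# themselves, so each position is trained to "predict" its own token.
identity_inputs = tokens
identity_targets = tokens

print(next_token_inputs, "->", next_token_targets)
print(identity_inputs, "->", identity_targets)
```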
So, in LLM terms, I'd be training a model to project from vocab space to a learned embedding space where each token had a distinct-enough embedding for the output head to be able to reliably project back to logits in vocab space. There's a bit of background here if that was all Greek to you.
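As a rough illustration of that idea, here is a sketch in plain JAX rather than Flax NNX. It is not the model from this post: the sizes, the parameter names (`embedding`, `head`), the learning rate and the plain gradient-descent update are all assumptions chosen to keep the example small and self-contained.

```python
import jax
import jax.numpy as jnp

VOCAB_SIZE = 16  # hypothetical; tiny so the example runs quickly
EMBED_DIM = 8    # hypothetical
LEARNING_RATE = 0.5  # hypothetical

def init_params(key):
    # One embedding table and one output head, wired directly together.
    k_emb, k_head = jax.random.split(key)
    return {
        "embedding": 0.1 * jax.random.normal(k_emb, (VOCAB_SIZE, EMBED_DIM)),
        "head": 0.1 * jax.random.normal(k_head, (EMBED_DIM, VOCAB_SIZE)),
    }

def forward(params, token_ids):
    embedded = params["embedding"][token_ids]  # vocab space -> embedding space
    return embedded @ params["head"]           # embedding space -> logits

def loss_fn(params, token_ids, target_ids):
    # Mean cross-entropy between the logits and the target token IDs.
    log_probs = jax.nn.log_softmax(forward(params, token_ids), axis=-1)
    picked = jnp.take_along_axis(log_probs, target_ids[:, None], axis=-1)
    return -jnp.mean(picked)

params = init_params(jax.random.PRNGKey(0))
token_ids = jnp.arange(VOCAB_SIZE)  # every token once; targets == inputs

initial_loss = loss_fn(params, token_ids, token_ids)

grad_fn = jax.jit(jax.value_and_grad(loss_fn))
for _ in range(300):
    loss, grads = grad_fn(params, token_ids, token_ids)
    # Plain gradient descent, standing in for a real optimiser.
    params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )

final_loss = loss_fn(params, token_ids, token_ids)
print("loss before:", float(initial_loss), "loss after:", float(final_loss))
```

If the plumbing is working, the loss after training should be lower than the loss before, since the embeddings and the head are being pushed towards a mapping that recovers each input token.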
Here's the core part of the code I was working with, the train_step function, which seems to be the traditional JAX name for the JITted part of your code that does the forward pass through the model,...