
Giles' blog


JAX backends and devices

Posted on 5 June 2026 in JAX, TIL, Python

There's nothing like writing your own code with a framework to clarify how things fit together! Continuing with my port of my PyTorch LLM code to JAX, I wanted to load up a large dataset: the 10,248,871,837 16-bit unsigned integers in the train split of gpjt/fineweb-gpt2-tokens. That's just over 19 GiB of data.
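The "just over 19 GiB" figure follows directly from the token count: each 16-bit unsigned integer takes two bytes, and a GiB is 2^30 bytes. A quick check in plain Python, using only the numbers given above:

```python
# Number of uint16 tokens in the train split, as given in the post.
num_tokens = 10_248_871_837
bytes_per_token = 2  # uint16 is 16 bits, i.e. 2 bytes

total_bytes = num_tokens * bytes_per_token
total_gib = total_bytes / 2**30  # 1 GiB = 2**30 bytes

print(f"{total_bytes:,} bytes = {total_gib:.2f} GiB")
# → 20,497,743,674 bytes = 19.09 GiB
```

That 19.09 GiB is the same size as the failed allocation in the error message further down.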

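from safetensors.flax import load_file
...
# dataset_dir is set up in the elided code (presumably a pathlib.Path);
# load_file returns a dict of tensor name -> JAX array, so pick out "tokens".
full_dataset = load_file(dataset_dir / "train.safetensors")["tokens"]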

When I ran that, I got a CUDA out-of-memory error:

jax.errors.JaxRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 19.09GiB.

That makes sense! The allocation it was trying to do is exactly the size of the data I was trying to load. I have an RTX 3090 with 24 GiB, but some of that is already used up by the OS, various apps, and a model that the code creates earlier on.

But in PyTorch land, I was used to things being loaded into RAM by default, and only moved over to the GPU when I asked it to do that. JAX was clearly loading to the GPU by default. How could I stop it from doing that for this case? The load into the GPU was happening inside Safetensors, in code I couldn't directly control.
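JAX's default placement is easy to see directly. The following is a minimal sketch, assuming a reasonably recent JAX release where arrays have a `.devices()` method. It creates an array on the default device, then shows two standard ways of getting data onto the CPU instead: `jax.device_put` for an array that already exists, and the `jax.default_device` context manager for arrays created inside a `with` block. Whether Safetensors' Flax loader honours `jax.default_device` is an assumption, not something checked here, so treat this as one possible direction rather than the fix the post settles on.

```python
import jax
import jax.numpy as jnp

# A freshly created array is placed on JAX's default device: the first
# device of the default backend (the GPU on a CUDA-enabled install).
x = jnp.arange(4, dtype=jnp.uint16)
print("x lives on:", x.devices())

# Ask the CPU backend for its first device by name.
cpu = jax.devices("cpu")[0]

# Option 1: copy an array that already exists onto the CPU.
x_cpu = jax.device_put(x, cpu)

# Option 2: make the CPU the default device for arrays created inside
# this block, so they are never allocated on the GPU in the first place.
with jax.default_device(cpu):
    y = jnp.zeros(4, dtype=jnp.uint16)

print("x_cpu lives on:", x_cpu.devices())
print("y lives on:", y.devices())
```

Only the second option addresses an out-of-memory error like the one above: `device_put` can move data off the GPU, but by then the GPU allocation has already had to succeed.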

Working out how to do that taught me a little bit more about how JAX fits together.

JAX has a function that looks relevant: jax.devices. Without reading the docs, let's try running it. In my virtualenv, with the jax[cuda13] package installed, I get this:

In [1]: import jax

In [2]: all_devices = jax.devices()

In [3]: all_devices
Out[3]: [CudaDevice(id=0)]

That seems a bit weird! I do indeed have a CUDA device, but I also have a CPU, obviously. Why isn't it showing up?

Running the same code in another virtualenv, with just jax installed -- no CUDA -- gets this:

In [1]: import jax

In [2]: all_devices = jax.devices()
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.

In [3]: all_devices
Out[3]: [CpuDevice(id=0)]
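Both sessions are consistent with how `jax.devices` is documented: it takes an optional backend name, and when none is given it returns only the devices of the default backend, which is a GPU or TPU backend if one is available and the CPU otherwise. A minimal sketch for inspecting this, assuming only the public `jax.default_backend` and `jax.devices` functions:

```python
import jax

# With no argument, jax.devices() lists only the default backend's devices.
default_backend = jax.default_backend()  # e.g. "gpu" or "cpu"
default_devices = jax.devices()

# Other backends have to be asked for by name; the CPU backend is one of them.
cpu_devices = jax.devices("cpu")

print("default backend:", default_backend)
print("default devices:", default_devices)
print("cpu devices:", cpu_devices)
```

In the CUDA virtualenv one would expect the default backend to be `gpu`, so the CPU only shows up when requested explicitly; in the CPU-only virtualenv, both lists are the same.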

OK, so...
