Speeding up MuJoCo 460x with JAX
An introduction to JAX and MJX for fast robotics simulation.
The Code
Setup
JAX Basics
How to Write Bad JAX
Basic JAX and MJX
Using JIT to Speed Things Up
How to Write Better JAX
Parallelism
Jitting the Whole Thing
Looping with Scan
A Full Data Collection Loop
Conclusion
Most roboticists know MuJoCo, Google DeepMind’s physics simulator for rigid-body robotics. Fewer have spent much time with JAX, Google’s library for high-performance numerical computing and machine learning. JAX is fiddly, but for parallel simulation it can be outrageously fast.
I’m currently working on a basic world model, which means I need to collect a lot of simulation data to train it. In this post I’ll show how JAX, via MJX (MuJoCo’s JAX-based backend), gives us a neat way to do that data collection quickly. Here’s a companion Google Colab you can run to try things for yourself.
Comparison of MJX to MuJoCo over a number of parallel Cartpole environments. Note that the y-axis is log scale. Below 16 environments, MuJoCo wins, but it scales poorly for parallel simulation. Tested on a Google Colab L4 GPU.
The timings below are steady-state after jit compilation. I’m reporting amortised time per environment step: total rollout wall-clock time divided by n_steps * n_envs (and by n_runs in the final example). For the single-environment examples, n_envs = 1.
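The timing convention above can be sketched as a small helper. `time_per_env_step` is hypothetical (not the benchmarking code behind the numbers in this post); it assumes a zero-argument `rollout_fn` that runs `n_steps` steps over `n_envs` environments. The warm-up call keeps jit compilation out of the measurement, and `jax.block_until_ready` matters because JAX dispatches work asynchronously, so without it you would mostly be timing how fast Python can queue work.

```python
import time

import jax
import jax.numpy as jnp


def time_per_env_step(rollout_fn, n_steps, n_envs=1, n_runs=1):
    """Amortised seconds per environment step (hypothetical helper).

    rollout_fn is assumed to take no arguments, run n_steps steps for
    n_envs environments, and return JAX arrays.
    """
    # Warm-up call: triggers jit compilation so we measure steady state
    jax.block_until_ready(rollout_fn())

    start = time.perf_counter()
    for _ in range(n_runs):
        # Wait for the device to finish before counting the run as done
        jax.block_until_ready(rollout_fn())
    elapsed = time.perf_counter() - start

    # Total wall-clock time divided by n_steps * n_envs * n_runs
    return elapsed / (n_steps * n_envs * n_runs)


# Usage with a dummy jitted "rollout" standing in for a real MJX rollout
dummy_rollout = jax.jit(lambda: jnp.ones((8, 100)).sum())
t = time_per_env_step(dummy_rollout, n_steps=100, n_envs=8)
```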
The Code
Setup
import jax
from mujoco_playground import registry

env_cfg = registry.get_default_config("CartpoleBalance")
env = registry.load("CartpoleBalance", config=env_cfg)
In this code I’m using mujoco_playground for a convenient MuJoCo/MJX environment. We load Cartpole, a basic environment offered in many RL codebases. The env object will be familiar to RL practitioners: it exposes a reset method, which returns a fresh initial state, and a step method, which takes the current state and an action and advances the simulation. Cartpole itself looks like this:
A simple Cartpole environment in MuJoCo, where the goal is to move the cart left and right to keep the pole balanced. This random policy is not doing a good job.
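Notice that step takes the state as an argument instead of mutating the env. JAX arrays are immutable and jit works on pure functions, so JAX environments are written in this functional, state-in/state-out style. The toy below sketches that shape with a 1-D point mass; `ToyState`, `toy_reset`, `toy_step` and `DT` are all hypothetical and are not mujoco_playground’s API.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    # Hypothetical stand-in for an env state: 1-D cart position and velocity
    pos: jax.Array
    vel: jax.Array


DT = 0.01  # assumed timestep for this toy


def toy_reset(key: jax.Array) -> ToyState:
    # Start near the origin with a small random offset
    return ToyState(pos=0.01 * jax.random.normal(key, (1,)), vel=jnp.zeros(1))


def toy_step(state: ToyState, action: jax.Array) -> ToyState:
    # Pure function: returns a NEW state rather than modifying the old one
    vel = state.vel + DT * action  # treat the action as an acceleration
    pos = state.pos + DT * vel
    return ToyState(pos=pos, vel=vel)


s0 = toy_reset(jax.random.key(0))
s1 = toy_step(s0, jnp.ones(1))  # s0 is left untouched
```

Because `toy_step` is pure and `ToyState` is a pytree, it can be passed straight to `jax.jit` or `jax.vmap`, which is exactly what makes this style worth the extra bookkeeping.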
JAX Basics
JAX largely resembles NumPy, which might make you wonder why you should care about it. What makes JAX special is its transforms: functions that take your NumPy-style function and return a modified version of it, for example one compiled to run fast on a GPU, or one that processes a whole batch of inputs at once. That is where the big speedups come from. The three most important transforms for our purposes are:
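jit: a ‘just-in-time’ compiler. It traces your Python function and compiles it with XLA into an optimised program that can run on a GPU. The first call is slow because it includes tracing and compilation, but later calls with the same input shapes and dtypes reuse the compiled code and are much faster.
vmap: short for ‘vectorising map’. It turns a function written for a single input into one that operates over a whole batch along an extra axis, so the GPU can process many inputs (for us, many environments) in parallel.
scan: ‘JAX-ifies’ Python ‘for’ loops. It expresses a loop whose state is carried from one iteration to the next, so the whole loop can be compiled and run on the device instead of returning to Python every iteration, netting a speedup for sequential jobs like simulation.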
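To make the three transforms concrete, here is a minimal sketch on a toy function rather than MJX; `energy` and the decay “dynamics” are made up purely for illustration.

```python
import jax
import jax.numpy as jnp


def energy(x):
    # Written for ONE state vector: returns a single scalar
    return 0.5 * jnp.sum(x ** 2)


# jit: traced and compiled on the first call; later calls with the same
# input shapes/dtypes reuse the compiled program
energy_jit = jax.jit(energy)

# vmap: maps energy over the leading axis of a batch of state vectors
energy_batched = jax.vmap(energy)

batch = jnp.arange(6.0).reshape(3, 2)  # 3 states with 2 dims each
total = energy_jit(batch)              # one scalar for the whole array
per_state = energy_batched(batch)      # shape (3,): one value per state


# scan: a compiled loop that threads a "carry" through every iteration
def body(x, _):
    x_next = 0.9 * x               # toy decay standing in for a sim step
    return x_next, energy(x_next)  # (new carry, per-step output to stack)


x0 = jnp.ones(2, dtype=jnp.float32)
x_final, energies = jax.lax.scan(body, x0, None, length=5)
```

Calling `energy` directly on the batch sums over everything, while the vmapped version keeps one result per state; and `scan` stacks the per-step outputs into `energies` with shape `(5,)`, which is the same pattern a simulation rollout uses to record a trajectory.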
How to Write Bad JAX
In the intro I mentioned that JAX is fiddly. It’s worth it, but it is absolutely fiddly. To illustrate that point, I’ll first write a naive implementation that looks like good old NumPy. Then we’ll JAX-ify our code step-by-step and watch the performance tick up.
Basic JAX and MJX
First let’s see what it looks like to run an episode in our environment. For these examples we only care about broad data collection, so I’ll just use random actions.
n_steps = 10
# JAX relies on manual RNG management. This is annoying, but it means
# everything is deterministic.
key = jax.random.key(0)
reset_key, key = jax.random.split(key)

# We reset the environment at the start to get it in a fresh state
state = env.reset(reset_key)

for t in range(n_steps):
    # Select a random Gaussian action, one value per actuator
    action_key, key = jax.random.split(key)
    action = jax.random.normal(action_key, (env.action_size,))

    # Step the environment
    state = env.step(state, action)
Time per step: 1.4s. This is pretty simple, and hopefully fairly intuitive! It shows off one JAX quirk: manual RNG management via keys. Reusing a key gives you the same random numbers again, so you have to do a painful dance of splitting keys every time you do something random. It’s irritating, but the determinism pays for itself in the long run, and it enables neat tricks like sharing large amounts of random data by sending nothing more than the integer seed that generates it. There’s a catch, however: this is slow as sin. On my MacBook Pro, it takes 14.1 seconds for 10 steps. Let’s track that stat (seconds per step) to see how things improve.
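Here is a small sketch of the key behaviour described above, using only `jax.random`. The last two lines show one common way to soften the splitting dance: split once into a batch of keys up front instead of splitting inside every loop iteration (the shape `(1,)` for each action is just an assumption for illustration).

```python
import jax

key = jax.random.key(0)

# Same key in, same numbers out: reusing a key repeats the "randomness"
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))

# Splitting yields fresh, independent keys for new random draws
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, (3,))

# Anyone who knows the seed 0 can regenerate `a` exactly
a_again = jax.random.normal(jax.random.key(0), (3,))

# Split once into many keys (one per step) and draw all actions in one go
step_keys = jax.random.split(jax.random.key(1), 10)
actions = jax.vmap(lambda k: jax.random.normal(k, (1,)))(step_keys)
```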
Using JIT to Speed Things Up
As nice as our code is, we’re not making use of our wonderful JAX transforms. Without them, every JAX operation inside env.step is dispatched from Python one at a time, which is why it’s so slow. For the first improvement, let’s jit:
n_steps = 1_000 # We can now run for more steps without dying of old age
# The only real change - wrap our env functions in jax.jit
reset = jax.jit(env.reset)
step = jax.jit(env.step)

key = jax.random.key(0)
reset_key, key = jax.random.split(key)
state = reset(reset_key)
for t in range(n_steps):
    action_key, key = jax.random.split(key)
    action =...