Speeding up MuJoCo 460x with JAX



An introduction to JAX and MJX for fast robotics simulation.


Most roboticists know MuJoCo, Google’s simulation library for rigid-body robotics. Fewer have spent much time with JAX, Google’s library for high-performance numerical computing and machine learning. JAX is fiddly, but for parallel simulation it can be outrageously fast.

I’m currently working on a basic world model, which means I need to collect a lot of simulation data to train it. In this post I’ll show how JAX, via MJX (MuJoCo’s JAX implementation), gives us a neat way to collect that data quickly. Here’s a companion Google Colab you can run to try things for yourself.

Comparison of MJX to MuJoCo over a number of parallel Cartpole environments. Note that the y-axis is log scale. Below 16 environments, MuJoCo wins, but it scales poorly for parallel simulation. Tested on a Google Colab L4 GPU.

The timings below are steady-state after jit compilation. I’m reporting amortised time per environment step: total rollout wall-clock time divided by n_steps * n_envs (and by n_runs in the final example). For the single-environment examples, n_envs = 1.
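To make the metric concrete: the 14.1 seconds for 10 single-environment steps reported below works out to 14.1 / (10 × 1) = 1.41 seconds per step. Below is a minimal sketch of one way to take this kind of measurement. It is an assumption, not necessarily how these timings were produced, and the helper `amortised_time_per_env_step` and the `toy_rollout` function are hypothetical. Two JAX details matter here. First, a warm-up call keeps jit compilation out of the measurement. Second, `jax.block_until_ready` is needed because JAX dispatches work asynchronously, so without it the clock could stop before the device has finished.

```python
import time

import jax
import jax.numpy as jnp


def amortised_time_per_env_step(fn, *args, n_steps, n_envs=1, n_runs=1):
    """Hypothetical helper: wall-clock seconds per environment step.

    Assumes a single call fn(*args) performs n_runs * n_steps * n_envs
    environment steps in total.
    """
    # Warm-up call so jit tracing/compilation is not included in the timing.
    jax.block_until_ready(fn(*args))
    start = time.perf_counter()
    # JAX dispatches asynchronously; wait for the result before reading the clock.
    jax.block_until_ready(fn(*args))
    elapsed = time.perf_counter() - start
    return elapsed / (n_steps * n_envs * n_runs)


@jax.jit
def toy_rollout(x):
    # Hypothetical stand-in for a rollout: 100 "steps" of a toy update,
    # applied to a batch of 8 "environments" at once.
    def body(carry, _):
        return 0.99 * carry + 1.0, None

    final, _ = jax.lax.scan(body, x, None, length=100)
    return final


x0 = jnp.zeros(8)
per_step = amortised_time_per_env_step(toy_rollout, x0, n_steps=100, n_envs=8)
print(f"{per_step:.2e} s per environment step")
```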

The Code

Setup

```python
import jax
from mujoco_playground import registry

env_cfg = registry.get_default_config("CartpoleBalance")
env = registry.load("CartpoleBalance", config=env_cfg)
```

In this code I’m using mujoco_playground for a convenient MuJoCo/MJX environment. We load Cartpole, a basic environment offered in many RL codebases. The env object will be familiar to RL practitioners. Its reset method takes a random key and returns an initial state. Its step method takes the current state and an action, advances the simulation, and returns the next state. Cartpole itself looks like this:

A simple Cartpole environment in MuJoCo, where the goal is to move the cart left and right to keep the pole balanced. This random policy is not doing a good job.
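The env object follows a functional style. In the loops below, `env.reset` takes a random key and `env.step` takes the current state plus an action. Both hand back a new state rather than mutating the environment, and that purity is what lets JAX transforms such as jit operate on them. The sketch below illustrates the same shape of interface with a hypothetical point-mass environment written in pure JAX. It is not mujoco_playground’s implementation: `ToyEnv`, `ToyState` and the dynamics are made up for illustration.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    # Hypothetical, much-simplified stand-in for an environment state.
    pos: jax.Array
    vel: jax.Array


class ToyEnv:
    """Hypothetical point-mass env mimicking a functional reset/step API."""

    dt = 0.01

    def reset(self, key: jax.Array) -> ToyState:
        # Random initial position; the key is the only source of randomness.
        pos = jax.random.uniform(key, (), minval=-0.1, maxval=0.1)
        return ToyState(pos=pos, vel=jnp.zeros(()))

    def step(self, state: ToyState, action: jax.Array) -> ToyState:
        # Pure function: returns a new state instead of mutating the old one.
        vel = state.vel + self.dt * action
        pos = state.pos + self.dt * vel
        return ToyState(pos=pos, vel=vel)


env = ToyEnv()
state = env.reset(jax.random.key(0))
new_state = env.step(state, jnp.array(1.0))

# Because step is pure, it can be wrapped in jax.jit directly.
jitted_step = jax.jit(env.step)
jitted_state = jitted_step(state, jnp.array(1.0))
```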

JAX Basics

JAX’s API largely resembles NumPy (jax.numpy mirrors much of it), which might make you wonder why you should care about it. What makes JAX special is its transforms: functions that take a Python function and return a transformed version of it. Transforms let you take NumPy-style code and run it efficiently on a GPU for a big speedup. The three most important transforms for our purposes are:

jit: a ‘just-in-time’ compiler that traces through your Python function and compiles it (via XLA) into an optimised program that can run on a GPU. The first call of a jitted function is slow while JAX traces and compiles it. Later calls with the same input shapes and dtypes reuse the compiled code and are much faster.

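vmap: short for ‘vectorising map’, it turns a function written for a single input into one that processes a whole batch at once. This lets you run many copies of your code in parallel and reap the benefits of using a GPU.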

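scan: jax.lax.scan ‘JAX-ifies’ Python ‘for’ loops that carry state from one iteration to the next. It compiles the whole loop as a single operation, which nets a speedup for sequential jobs like simulation.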
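Here is a minimal sketch of the three transforms applied to a toy function. The `step` function is a hypothetical damped update, not MJX physics. `jax.jit` compiles it, `jax.vmap` maps it over a batch axis, and `jax.lax.scan` runs it sequentially while carrying state. The last line composes vmap with the scanned rollout, which is the same pattern as running many environments in parallel.

```python
import jax
import jax.numpy as jnp


def step(x, action):
    # Hypothetical toy "dynamics": a damped update driven by an action.
    return 0.9 * x + action


# jit: trace once, compile with XLA, then reuse the compiled function.
fast_step = jax.jit(step)

# vmap: apply step independently along the leading axis of both arguments.
batched_step = jax.vmap(step)


def rollout(x0, actions):
    # scan: call step once per action, carrying x from one iteration to the next.
    def body(x, a):
        x_next = step(x, a)
        return x_next, x_next  # (new carry, per-step output to stack)

    return jax.lax.scan(body, x0, actions)


# One rollout of 5 steps, then 4 rollouts in parallel via vmap.
x_final, trajectory = jax.jit(rollout)(jnp.array(0.0), jnp.ones(5))
batch_final, batch_traj = jax.jit(jax.vmap(rollout))(jnp.zeros(4), jnp.ones((4, 5)))
```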

How to Write Bad JAX

In the intro I mentioned that JAX is fiddly. It’s worth it, but it is absolutely fiddly. To illustrate that point, I’ll first write a naive implementation that looks like good old NumPy. Then we’ll JAX-ify our code step-by-step and watch the performance tick up.

Basic JAX and MJX

First let’s see what it looks like to run an episode in our environment. For these examples we only care about broad data collection, so I’ll just use random actions.

```python
n_steps = 10
# JAX relies on manual RNG management. This is annoying, but it means
# everything is deterministic.
key = jax.random.key(0)
reset_key, key = jax.random.split(key)

# We reset the environment at the start to get it in a fresh state
state = env.reset(reset_key)

for t in range(n_steps):
    # Select a random Gaussian action
    action_key, key = jax.random.split(key)
    action = jax.random.normal(action_key)

    # Step the environment
    state = env.step(state, action)
```

Time per step: 1.4 s. This is pretty simple, and hopefully fairly intuitive! It shows off one JAX quirk: manual RNG management via keys. Reusing a key gives you the same random numbers again, so you need to do a painful dance of splitting keys every time you do something random. It’s irritating, but the determinism you get from this pays for itself in the long run. It also enables cool tricks like communicating massive amounts of data purely via single integer keys. There’s a catch, however: this is slow as sin. On my MacBook Pro, it takes 14.1 seconds for 10 steps, or about 1.4 seconds per step. Let’s track that stat (seconds per step) to see how things improve.
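The short sketch below shows the key behaviour described above using only `jax.random`. The same key always yields the same sample, and `jax.random.split` derives fresh keys. Splitting into many keys at once gives one key per parallel job, for example one per environment (the count of 8 is arbitrary).

```python
import jax

key = jax.random.key(0)

# Reusing the same key reproduces the same "random" sample.
a = jax.random.normal(key)
b = jax.random.normal(key)

# Splitting derives new, independent keys; the old key should then be retired.
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey)

# Split into many keys at once, e.g. one per parallel environment,
# and draw one sample per key with vmap.
env_keys = jax.random.split(key, 8)
samples = jax.vmap(jax.random.normal)(env_keys)
```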

Using JIT to Speed Things Up

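As nice as our code is, we’re not making use of our wonderful JAX transforms, and that’s why it’s so slow. Without jit, every small array operation inside env.step is dispatched from Python one at a time. For the first improvement, let’s jit: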

```python
n_steps = 1_000  # We can now run for more steps without dying of old age

# The only real change: wrap our env functions in jax.jit
reset = jax.jit(env.reset)
step = jax.jit(env.step)

key = jax.random.key(0)
reset_key, key = jax.random.split(key)

state = reset(reset_key)

for t in range(n_steps):
    action_key, key = jax.random.split(key)
    action =...
```
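One way to see what jit is doing is to put a Python side effect inside the jitted function. The side effect runs only while JAX traces the function, not when the compiled version is reused. In the sketch below, which uses a hypothetical `toy_step` rather than MJX code, calls with the same input shapes and dtypes hit the cache, while a new shape triggers a fresh trace and compile. This is why the first call to a jitted reset or step is slow and later calls are fast.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def toy_step(x, action):
    # Hypothetical toy update. The counter below is a Python side effect, so it
    # only runs while JAX traces the function, not on every compiled call.
    global trace_count
    trace_count += 1
    return 0.9 * x + action


toy_step(jnp.zeros(3), jnp.ones(3))  # first call: trace + compile
toy_step(jnp.ones(3), jnp.ones(3))   # same shapes/dtypes: reuses compiled code
toy_step(jnp.zeros(4), jnp.ones(4))  # new shape: traced and compiled again
print(trace_count)  # 2
```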
