Training a Simple World Model with Jax

ainch

World Model I

How to build a basic world model plus MPC in JAX.

What Is a World Model?

Learning to Model the World

Network Definition

The Training Loop

Using MPC to Derive a Policy

World Model: Rollout!

Generating Actions

Deriving Our Policy

How Does It Do?

The Bigger Picture

What Is a World Model?

There are lots of competing definitions of world models, and frankly the term is becoming so diluted it’s increasingly useless. Fei-Fei Li’s world models focus on 3D visual reconstruction. Yann LeCun’s JEPAs learn and predict underlying latent representations from video. Genie 3 is more action-oriented (check out my write-up here). Credible people like Kevin Murphy reckon LLMs constitute world models. For their part, Goldman Sachs think that… wait what? What bubble!?

Anyway, for my world model I’m going to take a more traditional reinforcement learning (RL) view: that means we’re talking about a state $s$ and an agent taking an action $a$. In that universe, a world model is fundamentally defined by predicting the next state $s'$ that follows a state-action pair:

\[f_\theta(s, a) = s'.\]

In this post, we’ll train a simple model $f_\theta$ and use model predictive control (MPC) to transmute it into a policy. We can do something cool here: by training a world model only on random data, we’ll derive a working policy even though our data collection was never designed for any particular task. In some sense, that’s what we do as humans: we think about what we want to get done, use our world knowledge, and squish the two together to come up with a novel solution. We do that without needing to know the problem in advance, unlike typical RL methods, where the reward function is usually a core part of training. As usual, the code is available in this notebook if you’d like to run it yourself.
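To make the idea concrete before we get there, here is a minimal random-shooting MPC sketch in plain JAX. It is a generic illustration, not necessarily the planner this post builds later. It samples candidate action sequences and rolls each one out through a world model by feeding every predicted state back in. It then scores the predicted states with a task reward and returns the first action of the best sequence. `toy_dynamics` (a 1D point mass standing in for a trained $f_\theta$), `reward_fn`, the horizon and the candidate count are all hypothetical.

```python
import jax
import jax.numpy as jnp


def toy_dynamics(s, a):
    """Hypothetical stand-in for a trained f_theta(s, a) -> s': a 1D point mass
    with state [position, velocity] and action = acceleration."""
    dt = 0.1
    return jnp.array([s[0] + dt * s[1], s[1] + dt * a[0]])


def reward_fn(s):
    """Hypothetical task reward: stay close to position 1.0."""
    return -((s[0] - 1.0) ** 2)


def plan_return(s0, actions):
    """Roll an action sequence of shape (T, action_dim) through the model,
    feeding each prediction back in, and sum the predicted rewards."""
    def step(s, a):
        s_next = toy_dynamics(s, a)
        return s_next, reward_fn(s_next)

    _, rewards = jax.lax.scan(step, s0, actions)
    return rewards.sum()


def random_shooting_mpc(key, s0, horizon=20, n_candidates=256, action_dim=1):
    """Sample random action sequences in [-1, 1), score each with the model,
    and return the first action of the best-scoring sequence."""
    candidates = jax.random.uniform(
        key, (n_candidates, horizon, action_dim), minval=-1.0, maxval=1.0
    )
    returns = jax.vmap(plan_return, in_axes=(None, 0))(s0, candidates)
    best = jnp.argmax(returns)
    return candidates[best, 0]


action = random_shooting_mpc(jax.random.PRNGKey(0), jnp.array([0.0, 0.0]))
print(action)  # first action of the best candidate sequence, shape (1,)
```

Only `reward_fn` knows about the task. Changing the goal therefore means changing that one function with no retraining of the model, which is the property described above.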

Learning to Model the World

Network Definition

I covered fast data collection using MJX in my last post, so I won’t dwell on the details here. Our data collection will be random, looking something like this.


Task-agnostic random action selection: the agent controls the horizontal speed of the cart, moving it either left or right.
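If you skipped the previous post, the shape of task-agnostic collection can be sketched without MJX. The idea is to vmap a step function over parallel environments, scan it over time with uniform random actions, and flatten the result into aligned $(s, a, s')$ arrays. `toy_step` is a hypothetical stand-in for the real MJX step: it treats the action as the cart's velocity and ignores the pole. The environment and step counts are arbitrary.

```python
import jax
import jax.numpy as jnp


def toy_step(s, a):
    """Hypothetical stand-in for one MJX environment step, NOT the real cartpole:
    the state is [cart position, cart velocity] and the action sets the velocity."""
    dt = 0.05
    return jnp.array([s[0] + dt * a[0], a[0]])


def collect_random(key, s0, n_envs=8, n_steps=100):
    """Roll out n_envs parallel copies of toy_step with uniform random actions
    in [-1, 1) and return flattened (s, a, s') transition arrays."""
    actions = jax.random.uniform(key, (n_steps, n_envs, 1), minval=-1.0, maxval=1.0)
    states0 = jnp.broadcast_to(s0, (n_envs, s0.shape[0]))

    def step(states, a):
        next_states = jax.vmap(toy_step)(states, a)  # step every env in parallel
        return next_states, (states, a, next_states)

    _, (s, a, s_next) = jax.lax.scan(step, states0, actions)
    # Merge time and env axes: (n_steps, n_envs, dim) -> (n_steps * n_envs, dim).
    return (
        s.reshape(-1, s.shape[-1]),
        a.reshape(-1, a.shape[-1]),
        s_next.reshape(-1, s_next.shape[-1]),
    )


s, a, s_next = collect_random(jax.random.PRNGKey(1), jnp.array([0.0, 0.0]))
print(s.shape, a.shape, s_next.shape)  # (800, 2) (800, 1) (800, 2)
```

Because the three arrays share a row index, a training batch is just one set of random indices applied to all of them.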

We’ll store and access transitions $(s, a, s')$. With the data collected, we can train a basic neural network to predict $s'$ given the initial state-action pair. I’m using Google’s flax.nnx library for this because it works well with the JAX transforms we’re using elsewhere.

```python
from typing import Callable

import jax.numpy as jnp
from flax import nnx
from jaxtyping import Array, Shaped


# Define one layer of the network.
class LayerBlock(nnx.Module):
    "A linear layer, optionally followed by batchnorm for stable training, then an activation."

    def __init__(
        self,
        in_features: int,
        out_features: int,
        activation_fn: Callable,
        rngs: nnx.Rngs,
        bn: bool = True,
    ):
        self.layers = nnx.List([
            nnx.Linear(in_features, out_features, rngs=rngs),
            nnx.BatchNorm(out_features, rngs=rngs) if bn else nnx.identity,
            activation_fn,
        ])

    def __call__(
        self, x: Shaped[Array, "... InFeatures"]
    ) -> Shaped[Array, "... OutFeatures"]:
        for layer in self.layers:
            x = layer(x)
        return x


# Stick a few layers together for our full network.
class OneStepWorldModel(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        # Input: 5-D state + 1-D action = 6 features; output: 5-D state.
        self.layers = nnx.List([
            LayerBlock(6, 32, activation_fn=nnx.swish, rngs=rngs, bn=True),
            LayerBlock(32, 64, activation_fn=nnx.swish, rngs=rngs, bn=True),
            LayerBlock(64, 32, activation_fn=nnx.swish, rngs=rngs, bn=True),
            # Skip batchnorm and an activation function for the output layer
            LayerBlock(32, 5, activation_fn=nnx.identity, rngs=rngs, bn=False),
        ])

    def __call__(
        self,
        obs: Shaped[Array, "... StateDim"],
        action: Shaped[Array, "... ActionDim"],
    ) -> Shaped[Array, "... StateDim"]:
        # Stack the state and action into one array
        x = jnp.concatenate([obs, action], axis=-1)
        for layer in self.layers:
            x = layer(x)

        obs = x + obs  # predict the delta
        # Normalise the sin and cos components (indices 1 and 2 of the last axis) to the unit circle
        sincos_norm = jnp.linalg.norm(obs[..., 1:3], axis=-1)
        obs = obs.at[..., 1].set(obs[..., 1] / sincos_norm)
        obs = obs.at[..., 2].set(obs[..., 2] / sincos_norm)
        return obs
```

The Training Loop

To train the model, I’m using a basic loss function that computes the squared distance between the model prediction $f_\theta(s_t, a_t)$ and the true next state $s_{t+1}$. We should really normalise each dimension of the state for this calculation. If, say, the cart velocity were 10x larger than everything else, it would dominate the un-normalised prediction error, and the model would focus on it at the expense of the other dimensions. But the un-normalised loss is good enough for this demo.

\[\mathcal{L}(\theta) = || f_\theta(s_t, a_t) - s_{t+1} ||^2\]
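The per-dimension normalisation mentioned above could look like the following sketch. It is not what this post's training loop does. Each squared error is divided by that dimension's standard deviation $\sigma_d$ over the dataset, giving $\mathcal{L}_{\text{norm}}(\theta) = \frac{1}{D}\sum_{d=1}^{D}\big((f_\theta(s_t, a_t)_d - s_{t+1,d}) / \sigma_d\big)^2$, so every state dimension contributes on a comparable scale. `compute_std` and `normalised_mse` are hypothetical names, and the toy batch is made up for illustration.

```python
import jax.numpy as jnp


def compute_std(next_obs, eps=1e-6):
    """Per-dimension standard deviation of the targets, computed once over the dataset.
    eps guards against division by zero for constant dimensions."""
    return next_obs.std(axis=0) + eps


def normalised_mse(pred, target, std):
    """Squared error with each state dimension scaled by its dataset std.

    Standardising pred and target as (x - mean) / std gives the same value,
    because the mean cancels in the difference.
    """
    return jnp.mean(((pred - target) / std) ** 2)


# A toy batch where dimension 0 is ten times larger than dimension 1.
targets = jnp.array([[0.0, 0.0], [10.0, 1.0]])
pred = targets + jnp.array([5.0, 0.5])  # each error equals one std in its own dimension
print(normalised_mse(pred, targets, compute_std(targets)))  # approximately 1.0
```

Without the division, the same errors would give a loss dominated by the 25.0 contribution from dimension 0, even though both dimensions are equally wrong relative to their scale.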

We’ll combine our loss function with a fairly bog-standard nnx training loop. The big difference from PyTorch is the use of nnx.value_and_grad: we ask JAX for the gradient of a specific function with respect to the model, rather than computing a loss and calling loss.backward(). For more on PyTorch vs JAX, check out this post.
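The same functional pattern is easy to see in plain JAX. This sketch is not the post's nnx loop. It swaps in a hypothetical linear model, a synthetic dataset stored as aligned $(s, a, s')$ arrays, random minibatches drawn by index, and hand-written gradient descent in place of an optimiser library. `jax.value_and_grad` returns both the loss and the gradient of `loss_fn` with respect to `params` in one call, which plays the role of `loss.backward()`.

```python
import jax
import jax.numpy as jnp

LR = 0.1  # hypothetical learning rate for plain gradient descent


def predict(params, s, a):
    """Hypothetical linear one-step model that predicts the state delta."""
    x = jnp.concatenate([s, a], axis=-1)
    return s + x @ params["W"] + params["b"]


def loss_fn(params, s, a, s_next):
    """Squared distance between predicted and true next states, averaged over the batch."""
    return jnp.mean(jnp.sum((predict(params, s, a) - s_next) ** 2, axis=-1))


@jax.jit
def train_step(params, s, a, s_next):
    # value_and_grad differentiates loss_fn with respect to its first argument
    # (params) and returns (loss, grads); grads has the same structure as params.
    loss, grads = jax.value_and_grad(loss_fn)(params, s, a, s_next)
    params = jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)
    return params, loss


# Synthetic dataset of (s, a, s') transitions: 2-D state, 1-D action.
key = jax.random.PRNGKey(0)
key, k_s, k_a = jax.random.split(key, 3)
states = jax.random.normal(k_s, (1024, 2))
acts = jax.random.uniform(k_a, (1024, 1), minval=-1.0, maxval=1.0)
next_states = states + 0.1 * acts  # "true" dynamics; the action term broadcasts over both dims

params = {"W": jnp.zeros((3, 2)), "b": jnp.zeros(2)}
initial_loss = loss_fn(params, states, acts, next_states)
for _ in range(300):
    key, k_batch = jax.random.split(key)
    idx = jax.random.randint(k_batch, (64,), 0, states.shape[0])  # minibatch, with replacement
    params, _ = train_step(params, states[idx], acts[idx], next_states[idx])
final_loss = loss_fn(params, states, acts, next_states)
print(float(initial_loss), float(final_loss))  # the second should be far smaller
```

Because the gradients come back as a pytree shaped like `params`, the update is a `tree_map` over the parameters rather than an in-place `optimizer.step()`. In nnx, `nnx.value_and_grad` plays the same role for an `nnx.Module`.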

```python
model = OneStepWorldModel(nnx.Rngs(0))

n_train_steps = 2000
# Use a simple...
```
