Poverty Bayes: fitting million-parameter models for pennies with serverless MCMC

Christopher Krapu

It's a good time to be an applied probabilist. The deep learning revolution has led to tremendous improvements in the $/FLOP department, and we Bayesians can easily hop on this train! During grad school, I used to spend nights and weekends babysitting MCMC runs on my GeForce Titan XP running in my bedroom (by the way, thank you, NVIDIA Academic Grant Program) while simultaneously trying to keep the waste heat from cooking me as I slept. If you are a newcomer to this field, rejoice in the knowledge that all this suffering is a thing of the past. A slew of companies are rushing to the fore with user-friendly platforms for renting GPUs. For prototyping, I really enjoy working with Modal, since I'm cheap and too lazy to keep managing my own fleet.

In this post, I'll show a workflow for GPU-based inference on Modal. The model is very large by the standards of Bayesian statisticians (\(\vert \theta \vert > 10^6\)), and we fit it by deploying to a datacenter GPU and renting it only for a short time.

Model & data

We'll use synthetic data for this example.

I've chosen a hierarchical logistic regression for this post for three reasons: it has a non-conjugate likelihood, it appears commonly in practice, and it can easily be scaled to more parameters by increasing the number of covariates and/or the number of groups.

Let \(i\), \(g\), and \(k\) denote the indices over observations, groups, and covariates. Furthermore, let \(x_i \in \mathbb{R}^{20}\) be the covariate vector for observation \(i\), and let \(g_i \in \{1,\ldots,100000\}\) identify its group. The data-generating process uses a global intercept \(\alpha\), population-level slopes \(\beta_k\), group intercept deviations \(\alpha_g\), and group slope deviations \(\gamma_{gk}\).
The binary outcome is generated from

\[Y_i \sim \operatorname{Bernoulli}(p_i), \qquad \operatorname{logit}(p_i) = \alpha + \alpha_{g_i} + \sum_{k=1}^{K} x_{ik}(\beta_k + \gamma_{g_i k}).\]

Essentially, this is a logistic regression with a random intercept for each of the 100,000 groups and random slopes for all 20 covariates.

We'll use a non-centered parameterization for the group effects. The prior specification is as follows, where the second argument of each Normal and HalfNormal is a scale (standard deviation):

\[
\begin{aligned}
\alpha &\sim \operatorname{Normal}(0, 1.5), \\
\beta_k &\sim \operatorname{Normal}(0, 1), \\
\sigma_\alpha &\sim \operatorname{HalfNormal}(1), \\
\sigma_{\gamma,k} &\sim \operatorname{HalfNormal}(0.5), \\
z_{\alpha,g} &\sim \operatorname{Normal}(0, 1), \\
z_{\gamma,gk} &\sim \operatorname{Normal}(0, 1), \\
\alpha_g &= \sigma_\alpha z_{\alpha,g}, \\
\gamma_{gk} &= \sigma_{\gamma,k} z_{\gamma,gk}.
\end{aligned}
\]

The code below produces a synthetic dataset. We can control the overall sparsity of the response with the value of α_true.

```python
import numpy as np

RANDOM_SEED = 827
rng = np.random.default_rng(RANDOM_SEED)

N = 1_000_000  # Number of data points
G = 100_000  # Number of groups
K = 20  # Number of covariates / features

# Group membership and covariates for each observation
group_idx = rng.integers(0, G, size=N, dtype=np.int64)
X = rng.normal(size=(N, K)).astype(np.float32)

# Ground-truth parameters
α_true = np.float32(-1.0)
β_true = rng.normal(0.0, 0.45, size=K).astype(np.float32)
σ_α_true = np.float32(0.80)
σ_γ_true = rng.uniform(0.15, 0.35, size=K).astype(np.float32)
α_group_true = rng.normal(0.0, σ_α_true, size=G).astype(np.float32)
γ_group_true = rng.normal(0.0, σ_γ_true, size=(G, K)).astype(np.float32)

# Linear predictor, success probability (inverse logit), and binary outcome
η = α_true + α_group_true[group_idx] + np.sum(X * (β_true + γ_group_true[group_idx]), axis=1)
p = 1.0 / (1.0 + np.exp(-η))
y = rng.binomial(1, p).astype(np.int64)

print(f"Average of y: {np.mean(y):.2f}")
```

Average of y: 0.36
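To see how α_true controls the sparsity of the response, the sketch below reruns a reduced-size, float64 version of the generator above for several intercepts. It prints the average of y for each one. The helper name `simulate_mean_y`, the smaller sizes (50,000 observations, 1,000 groups), and the grid of α_true values are illustrative assumptions, not part of the original workflow. The printed 0.36 above corresponds to α_true = -1.0 at full size.

```python
import numpy as np


def simulate_mean_y(α_true, N=50_000, G=1_000, K=20, seed=827):
    """Hypothetical helper: mean of y for a scaled-down copy of the synthetic data."""
    rng = np.random.default_rng(seed)
    group_idx = rng.integers(0, G, size=N)
    X = rng.normal(size=(N, K))
    β_true = rng.normal(0.0, 0.45, size=K)
    σ_α_true = 0.80
    σ_γ_true = rng.uniform(0.15, 0.35, size=K)
    α_group_true = rng.normal(0.0, σ_α_true, size=G)
    # The (K,)-shaped scale broadcasts across the (G, K) output.
    γ_group_true = rng.normal(0.0, σ_γ_true, size=(G, K))
    η = α_true + α_group_true[group_idx] + np.sum(X * (β_true + γ_group_true[group_idx]), axis=1)
    p = 1.0 / (1.0 + np.exp(-η))
    y = rng.binomial(1, p)
    return float(y.mean())


for α_true in (-3.0, -2.0, -1.0, 0.0, 1.0):
    print(f"α_true = {α_true:+.1f} -> average of y: {simulate_mean_y(α_true):.2f}")
```

Lowering α_true shifts every logit down and pushes the average of y toward zero. Because the rest of η is symmetric around zero, α_true = 0 gives an average close to 0.5.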
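The data above were simulated with centered draws, for example `rng.normal(0.0, σ_α_true, size=G)`. The prior specification instead uses a non-centered parameterization, \(\alpha_g = \sigma_\alpha z_{\alpha,g}\) with \(z_{\alpha,g} \sim \operatorname{Normal}(0, 1)\). The quick numpy check below compares the two constructions. The seed, σ = 0.8, and sample size are arbitrary illustrative choices, not values from the original. The check shows that both yield the same Normal(0, σ) distribution.

```python
import numpy as np

rng = np.random.default_rng(0)
σ = 0.8
n = 200_000

# Centered: draw the group effects directly with scale σ.
centered = rng.normal(0.0, σ, size=n)

# Non-centered: draw standardized z ~ Normal(0, 1), then scale deterministically by σ.
z = rng.normal(0.0, 1.0, size=n)
non_centered = σ * z

qs = [0.05, 0.5, 0.95]
print(f"centered:     mean={centered.mean():+.3f}, sd={centered.std():.3f}")
print(f"non-centered: mean={non_centered.mean():+.3f}, sd={non_centered.std():.3f}")
print("quantiles (centered):    ", np.round(np.quantile(centered, qs), 3))
print("quantiles (non-centered):", np.round(np.quantile(non_centered, qs), 3))
```

The two forms agree in distribution, so the choice affects only the geometry the sampler sees. The usual motivation for the non-centered form is that the sampled quantities z are a priori independent of σ. This tends to avoid the funnel-shaped posteriors that hierarchical models develop when groups carry little data.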

Here, we define our model in PyMC. This is a fairly standard model definition without any special tricks.

```python
import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt

# Build the graph in single precision, which is typically much faster on GPUs.
pytensor.config.floatX = "float32"


def build_model(X, y, group_idx):
    X = np.asarray(X, dtype=np.float32)
    y = np.asarray(y, dtype=np.int64)
    group_idx = np.asarray(group_idx, dtype=np.int64)

    N, K = X.shape
    G = int(group_idx.max()) + 1
    coords = {
        "obs": np.arange(N),
        "group": np.arange(G),
        "covariate": np.arange(K),
    }

    with pm.Model(coords=coords) as model:
        X_data = pm.Data("X", X, dims=("obs", "covariate"))
        group_lookup = pt.as_tensor_variable(group_idx, name="group_lookup")

        # Population-level parameters and group-effect scales
        α = pm.Normal("α", np.float32(0.0), np.float32(1.5))
        β = pm.Normal("β", np.float32(0.0), np.float32(1.0), dims="covariate")
        σ_α = pm.HalfNormal("σ_α", np.float32(1.0))
        σ_γ = pm.HalfNormal("σ_γ", np.float32(0.5), dims="covariate")

        # Non-centered group effects: standard-normal z scaled by σ
        z_α = pm.Normal("z_α", np.float32(0.0), np.float32(1.0), dims="group")
        z_γ = pm.Normal("z_γ", np.float32(0.0), np.float32(1.0), dims=("group", "covariate"))
        α_group = σ_α * z_α
        γ_group = σ_γ * z_γ

        # Linear predictor on the logit scale and Bernoulli likelihood
        η = α + α_group[group_lookup] + pm.math.sum(X_data * (β + γ_group[group_lookup]), axis=1)
        pm.Bernoulli("Y", logit_p=η, observed=y, dims="obs")

    return model
```
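To check the "million-parameter" claim, the arithmetic below counts the free parameters in the model, block by block, using the sizes from the synthetic data. The dictionary keys simply mirror the variable names in `build_model`. \(\alpha_g\) and \(\gamma_{gk}\) are deterministic transforms of z and σ, so they add no dimensions of their own.

```python
# Sizes used in the synthetic-data code above.
G = 100_000  # groups
K = 20  # covariates

# Free (sampled) parameters per block of the prior specification.
blocks = {
    "α": 1,
    "β": K,
    "σ_α": 1,
    "σ_γ": K,
    "z_α": G,
    "z_γ": G * K,
}

total = sum(blocks.values())
for name, size in blocks.items():
    print(f"{name:>4}: {size:>9,}")
print(f"total: {total:,}")
```

This gives \(1 + K + 1 + K + G + GK = 2{,}100{,}042\) parameters, comfortably above \(10^6\). Almost all of them sit in z_γ, the 100,000 × 20 matrix of standardized slope deviations.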

To illustrate why this model could be challenging to fit using only a CPU, we profile the gradient evaluation...
