# Implementing Bayesian Inference
Bayesian inference is usually implemented in the style of probabilistic programming: the probability model itself is written as code. Concrete tools include Stan and NumPyro.
## Example
Consider the linear regression model

$$
y_i = \alpha + X_i \beta + e_i, \quad e_i \sim \mathcal{N}(0, \sigma).
$$

Written as a full probabilistic model (likelihood and priors):
$$
\begin{aligned}
y_i & \sim \mathcal{N}\left(\alpha + X_i \beta, \sigma\right) \\
\alpha & \sim \mathcal{N}(0, 1) \\
\beta_j & \sim \mathcal{N}(0, 1) \\
\sigma & \sim \operatorname{HalfNormal}(1) \quad (\sigma \geq 0)
\end{aligned}
$$
```python
# generate sample data
import numpy as np

rng = np.random.default_rng(0)
N, K = 200, 3
X = rng.normal(size=(N, K))
alpha_true = 1.0
beta_true = np.array([0.5, -1.2, 0.3])
sigma_true = 0.7
y = alpha_true + X @ beta_true + rng.normal(scale=sigma_true, size=N)
```
## Stan
```stan
data {
  int<lower=0> N;       // number of data items
  int<lower=0> K;       // number of predictors
  matrix[N, K] x;       // predictor matrix
  vector[N] y;          // outcome vector
}
parameters {
  real alpha;           // intercept
  vector[K] beta;       // coefficients for predictors
  real<lower=0> sigma;  // error scale
}
model {
  alpha ~ normal(0, 1);
  beta ~ normal(0, 1);
  sigma ~ normal(0, 1);  // half-normal, since sigma is declared with <lower=0>
  y ~ normal(x * beta + alpha, sigma);  // likelihood
}
```
An example using the cmdstanpy package:
```python
from cmdstanpy import CmdStanModel

model = CmdStanModel(stan_file="hoge.stan")
data = {"N": N, "K": K, "x": X, "y": y}  # keys must match the names in the Stan data block
fit = model.sample(
    data=data,
    chains=4,
    iter_warmup=1000,
    iter_sampling=1000,
    seed=0,
)
df = fit.draws_pd()  # pandas DataFrame (Stan indexes from 1, hence beta[1]..beta[3])
print(df[["alpha", "beta[1]", "beta[2]", "beta[3]", "sigma"]].describe())
```
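As a quick convergence check, CmdStanPy exposes Stan's built-in summary and diagnostic utilities; a minimal sketch assuming the `fit` object above:

```python
# posterior summary including R-hat and effective sample sizes
print(fit.summary())

# Stan's diagnose utility: divergences, tree depth, E-BFMI, etc.
print(fit.diagnose())
```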
## NumPyro
NumPyro is built on JAX, a high-performance scientific computing library.
```python
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS


def model(X, y=None):
    N, K = X.shape
    # priors
    alpha = numpyro.sample("alpha", dist.Normal(0.0, 1.0))
    beta = numpyro.sample("beta", dist.Normal(0.0, 1.0).expand([K]))
    sigma = numpyro.sample("sigma", dist.HalfNormal(1.0))
    # likelihood
    mu = alpha + jnp.dot(X, beta)
    numpyro.sample("y", dist.Normal(mu, sigma), obs=y)


nuts = NUTS(model)
mcmc = MCMC(nuts, num_warmup=1000, num_samples=1000, num_chains=4, progress_bar=False)
mcmc.run(jax.random.PRNGKey(0), X=jnp.array(X), y=jnp.array(y))

samples = mcmc.get_samples(group_by_chain=False)
# samples["beta"] shape: (num_draws, K)
print({k: (v.mean(0), v.std(0)) for k, v in samples.items() if k in ["alpha", "sigma"]})
print("beta mean:", samples["beta"].mean(0))
print("beta sd :", samples["beta"].std(0))
```
On a single CPU device NumPyro warns that there are not enough devices for 4 parallel chains and draws them sequentially; calling `numpyro.set_host_device_count(4)` at the beginning of the program enables parallel chains (check the available devices with `jax.local_device_count()`).
```
{'alpha': (Array(0.9722434, dtype=float32), Array(0.05137558, dtype=float32)), 'sigma': (Array(0.710881, dtype=float32), Array(0.03627468, dtype=float32))}
beta mean: [ 0.43488577 -1.2650281   0.336224  ]
beta sd : [0.05001275 0.05104848 0.05170776]
```
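Posterior predictive draws can be generated by reusing the posterior samples; a minimal sketch assuming the `model`, `samples`, and data defined above:

```python
from numpyro.infer import Predictive

# with y=None the "y" site is sampled rather than conditioned on,
# so each posterior draw produces one simulated dataset
predictive = Predictive(model, posterior_samples=samples)
y_pred = predictive(jax.random.PRNGKey(1), X=jnp.array(X))["y"]
print(y_pred.shape)  # (num_draws, N)
```

For quick diagnostics, `mcmc.print_summary()` also reports posterior statistics together with `n_eff` and `r_hat`.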
## PyMC
PyMC has a similar feel to NumPyro, but also lets you swap in a different sampler.
### Feature: a choice of MCMC samplers
- Python NUTS sampler (the default; can be faster than NumPyro)
- NumPyro JAX NUTS sampler
- BlackJAX NUTS sampler (reportedly fast on large datasets)
- Nutpie NUTS sampler (written in Rust; reportedly about as fast as JAX)

The sampler is selected via the `nuts_sampler` argument, e.g. `pm.sample(nuts_sampler="blackjax")`.

Reference: Faster Sampling with JAX and Numba — PyMC example gallery
```python
import pymc as pm
import numpy as np

with pm.Model() as m:
    # note: somewhat wider priors than the N(0, 1) / HalfNormal(1) used above
    alpha = pm.Normal("alpha", mu=0.0, sigma=5.0)
    beta = pm.Normal("beta", mu=0.0, sigma=2.0, shape=K)
    sigma = pm.HalfNormal("sigma", sigma=2.0)
    mu = alpha + pm.math.dot(X, beta)
    y_obs = pm.Normal("y", mu=mu, sigma=sigma, observed=y)
    idata = pm.sample(
        draws=1000,
        tune=1000,
        chains=4,
        random_seed=0,
        target_accept=0.8,
    )

# ArviZ makes it easy to summarize and compare the results
import arviz as az

display(az.summary(idata, var_names=["alpha", "beta", "sigma"]))
```
```
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 2 jobs)
NUTS: [alpha, beta, sigma]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 3 seconds.
```
| | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|
| alpha | 0.975 | 0.051 | 0.879 | 1.071 | 0.001 | 0.001 | 5105.0 | 3164.0 | 1.0 |
| beta[0] | 0.437 | 0.050 | 0.343 | 0.529 | 0.001 | 0.001 | 5539.0 | 3541.0 | 1.0 |
| beta[1] | -1.267 | 0.051 | -1.363 | -1.172 | 0.001 | 0.001 | 7651.0 | 3447.0 | 1.0 |
| beta[2] | 0.338 | 0.052 | 0.243 | 0.437 | 0.001 | 0.001 | 5717.0 | 3368.0 | 1.0 |
| sigma | 0.712 | 0.038 | 0.646 | 0.785 | 0.001 | 0.001 | 5299.0 | 3283.0 | 1.0 |
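For visual diagnostics, ArviZ trace plots work directly on the returned `InferenceData`; a minimal sketch assuming the `idata` object above:

```python
import arviz as az

# one row of panels per parameter; well-mixed chains should overlap
az.plot_trace(idata, var_names=["alpha", "beta", "sigma"])
```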