# ベイズ推定の実装

確率モデルをコードとして書く考え方(確率的プログラミング Probabilistic Programming)をとることが多い。

具体的なツールとしては Stan や NumPyro などを使う。

## 線形回帰モデル

\[ y_i = \alpha+X_i \beta + e_i, \quad e_i \sim \mathcal{N}(0, \sigma) \]

について、

\[\begin{split} \begin{aligned} y_i & \sim \mathcal{N}\left(\alpha+X_i \beta, \sigma\right) \\ \alpha & \sim \mathcal{N}(0,1) \\ \beta_j & \sim \mathcal{N}(0,1) \\ \sigma & \sim \operatorname{HalfNormal}(1) \quad(\sigma \geq 0) \end{aligned} \end{split}\]
# --- Synthetic data for the running linear-regression example ---
import numpy as np

rng = np.random.default_rng(0)  # fixed seed so every sampler sees the same data
N, K = 200, 3                   # observations, predictors
X = rng.standard_normal((N, K))

# Ground-truth parameters the samplers should recover
alpha_true = 1.0
beta_true = np.array([0.5, -1.2, 0.3])
sigma_true = 0.7

# y = alpha + X beta + Gaussian noise with scale sigma_true
y = alpha_true + X.dot(beta_true) + sigma_true * rng.standard_normal(N)

## Stan

data {
  int<lower=0> N;   // number of data items
  int<lower=0> K;   // number of predictors
  matrix[N, K] x;   // predictor matrix
  vector[N] y;      // outcome vector
}
parameters {
  real alpha;           // intercept
  vector[K] beta;       // coefficients for predictors
  real<lower=0> sigma;  // error scale (std. deviation of the noise)
}
model {
  // Standard-normal priors on the location parameters
  alpha ~ normal(0, 1);
  beta  ~ normal(0, 1);
  sigma ~ normal(0, 1); // effectively half-normal because sigma is declared with <lower=0>
  y ~ normal(x * beta + alpha, sigma);  // likelihood
}

cmdstanpyパッケージを使う場合の例

# Fit the Stan program above with cmdstanpy.
from cmdstanpy import CmdStanModel

model = CmdStanModel(stan_file="hoge.stan")

# NOTE: the Stan `data` block declares the predictor matrix as `x`
# (lowercase), so the dict key must be "x" — passing "X" makes CmdStan
# fail with a "variable x not found" error.
data = {"N": N, "K": K, "x": X, "y": y}
fit = model.sample(
    data=data,
    chains=4,
    iter_warmup=1000,    # warmup (adaptation) iterations per chain
    iter_sampling=1000,  # retained posterior draws per chain
    seed=0,
)

df = fit.draws_pd()  # posterior draws as a pandas DataFrame
print(df[["alpha", "beta[1]", "beta[2]", "beta[3]", "sigma"]].describe())

## NumPyro

JAXという高速な科学計算ライブラリを使っている

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def model(X, y=None):
    """NumPyro model mirroring the Stan program.

    Standard-normal priors on the intercept and coefficients,
    HalfNormal(1) on the noise scale; `y=None` lets the same model
    be reused for prior/posterior predictive sampling.
    """
    _, n_features = X.shape
    intercept = numpyro.sample("alpha", dist.Normal(0.0, 1.0))
    coefs = numpyro.sample("beta", dist.Normal(0.0, 1.0).expand([n_features]))
    noise_scale = numpyro.sample("sigma", dist.HalfNormal(1.0))
    # Gaussian likelihood around the linear predictor
    numpyro.sample("y", dist.Normal(intercept + X @ coefs, noise_scale), obs=y)

# Run NUTS: 4 chains of 1000 warmup + 1000 retained draws each.
kernel = NUTS(model)
sampler = MCMC(kernel, num_warmup=1000, num_samples=1000, num_chains=4, progress_bar=False)
sampler.run(jax.random.PRNGKey(0), X=jnp.array(X), y=jnp.array(y))

posterior = sampler.get_samples(group_by_chain=False)
# posterior["beta"] shape: (num_draws, K)
print({k: (v.mean(0), v.std(0)) for k, v in posterior.items() if k in ["alpha", "sigma"]})
print("beta mean:", posterior["beta"].mean(0))
print("beta sd  :", posterior["beta"].std(0))
/tmp/ipykernel_8259/2230612519.py:16: UserWarning: There are not enough devices to run parallel chains: expected 4 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(4)` at the beginning of your program. You can double-check how many devices are available in your system using `jax.local_device_count()`.
  mcmc = MCMC(nuts, num_warmup=1000, num_samples=1000, num_chains=4, progress_bar=False)
{'alpha': (Array(0.9722434, dtype=float32), Array(0.05137558, dtype=float32)), 'sigma': (Array(0.710881, dtype=float32), Array(0.03627468, dtype=float32))}
beta mean: [ 0.43488577 -1.2650281   0.336224  ]
beta sd  : [0.05001275 0.05104848 0.05170776]

## PyMC

NumPyroと近い書き味だが、サンプラーを他のものにすることもできる

### 特徴:MCMCサンプラーを選べる

  • Python NUTS sampler (デフォルト、NumPyroより速いことも)

  • NumPyro JAX NUTS sampler

  • BlackJAX NUTS sampler(大規模データで速いらしい)

  • Nutpie NUTS sampler(Rustで書かれていてJAXくらい速いらしい)

pm.sample(nuts_sampler="blackjax")

Faster Sampling with JAX and Numba — PyMC example gallery

import pymc as pm
import numpy as np

# Same model as the Stan / NumPyro versions: standard-normal priors on
# the location parameters and HalfNormal(1) on the error scale.
# (The priors previously used here — N(0,5), N(0,2), HalfNormal(2) —
# did not match the model stated above; aligned for consistency.)
with pm.Model() as m:
    alpha = pm.Normal("alpha", mu=0.0, sigma=1.0)
    beta  = pm.Normal("beta",  mu=0.0, sigma=1.0, shape=K)
    sigma = pm.HalfNormal("sigma", sigma=1.0)

    mu = alpha + pm.math.dot(X, beta)
    y_obs = pm.Normal("y", mu=mu, sigma=sigma, observed=y)

    idata = pm.sample(
        draws=1000,        # retained draws per chain
        tune=1000,         # warmup / adaptation iterations per chain
        chains=4,
        random_seed=0,
        target_accept=0.8, # NUTS acceptance target (default)
    )

# ArviZ makes chain diagnostics and summaries easy to compare
import arviz as az
display(az.summary(idata, var_names=["alpha", "beta", "sigma"]))
Initializing NUTS using jitter+adapt_diag...
/home/runner/work/notes/notes/.venv/lib/python3.10/site-packages/pytensor/link/c/cmodule.py:2968: UserWarning: PyTensor could not link to a BLAS installation. Operations that might benefit from BLAS will be severely degraded.
This usually happens when PyTensor is installed via pip. We recommend it be installed via conda/mamba/pixi instead.
Alternatively, you can use an experimental backend such as Numba or JAX that perform their own BLAS optimizations, by setting `pytensor.config.mode == 'NUMBA'` or passing `mode='NUMBA'` when compiling a PyTensor function.
For more options and details see https://pytensor.readthedocs.io/en/latest/troubleshooting.html#how-do-i-configure-test-my-blas-library
  warnings.warn(
Multiprocess sampling (4 chains in 2 jobs)
NUTS: [alpha, beta, sigma]
/home/runner/.local/share/uv/python/cpython-3.10.19-linux-x86_64-gnu/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = os.fork()
/home/runner/.local/share/uv/python/cpython-3.10.19-linux-x86_64-gnu/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = os.fork()

Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 3 seconds.
|         | mean   | sd    | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---------|--------|-------|--------|---------|-----------|---------|----------|----------|-------|
| alpha   | 0.975  | 0.051 | 0.879  | 1.071   | 0.001     | 0.001   | 5105.0   | 3164.0   | 1.0   |
| beta[0] | 0.437  | 0.050 | 0.343  | 0.529   | 0.001     | 0.001   | 5539.0   | 3541.0   | 1.0   |
| beta[1] | -1.267 | 0.051 | -1.363 | -1.172  | 0.001     | 0.001   | 7651.0   | 3447.0   | 1.0   |
| beta[2] | 0.338  | 0.052 | 0.243  | 0.437   | 0.001     | 0.001   | 5717.0   | 3368.0   | 1.0   |
| sigma   | 0.712  | 0.038 | 0.646  | 0.785   | 0.001     | 0.001   | 5299.0   | 3283.0   | 1.0   |