PyMC5

ベイズ推定用のライブラリ

Home — PyMC project website

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm

# Print the installed PyMC version.
print(f"Running on PyMC v{pm.__version__}")
Running on PyMC v5.25.1
# Create the model context; the random variable is declared inside the
# with-block so that it is attached to this model.
with pm.Model() as model:
    x = pm.Binomial("x", n=5, p=0.5)
# Display the random variable (last expression of the cell).
x
\[\text{x} \sim \operatorname{Binomial}(5,~0.5)\]
# Calls made inside the model's context (the with-block) are tied to that Model.
with model:
    # Draw samples from the prior predictive distribution.
    prior_samples = pm.sample_prior_predictive(
        draws=500,
        random_seed=0,
    )
Sampling: [x]
# Display the sampling result (shown below as an arviz.InferenceData object).
prior_samples
arviz.InferenceData
    • <xarray.Dataset> Size: 8kB
      Dimensions:  (chain: 1, draw: 500)
      Coordinates:
        * chain    (chain) int64 8B 0
        * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
      Data variables:
          x        (chain, draw) int64 4kB 4 2 3 1 2 3 1 4 2 3 ... 2 3 2 3 2 2 3 2 2 3
      Attributes:
          created_at:                 2025-12-02T14:37:57.729777+00:00
          arviz_version:              0.21.0
          inference_library:          pymc
          inference_library_version:  5.25.1

# arviz: visualization library
import arviz as az
# Tabulate summary statistics of the prior predictive samples.
az.summary(prior_samples)
arviz - WARNING - Shape validation failed: input_shape: (1, 500), minimum_shape: (chains=2, draws=4)
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
x 2.482 1.142 0.0 4.0 0.05 0.03 513.0 441.0 NaN
import numpy as np
# Extract the sampled values of "x" from the "prior" group as a raw array.
# The annotation is np.ndarray (the array type); np.array is a factory
# function, not a type, so it is not a valid annotation for the result.
x_samples: np.ndarray = prior_samples["prior"]["x"].values
# Plot the distribution of the prior draws.
az.plot_dist(x_samples)
<Axes: >
../../_images/e0cd2657997af803465e82524c5f180b267ea8c6a7e14193105249bf2cbb8aa6.png

モデルの定義とグラフ表記

import pymc as pm
import numpy as np

# Observed data
X = np.array([1, 0, 0, 1, 0])

with pm.Model() as model:
    # Declare that the parameter p follows a uniform distribution on [0, 1].
    p = pm.Uniform("p", lower=0.0, upper=1.0)
    # Declare that the observations X follow a Bernoulli distribution with parameter p.
    X_obs = pm.Bernoulli("X_obs", observed=X, p=p)

# Render the model as a Graphviz graph.
pm.model_to_graphviz(model)
../../_images/10b42c20260f954a807be30cc616a17ab9bc4d4e6ea5d16e7ec18e9dc2600db0.svg

MCMC

# Run MCMC sampling within the model's context.
with model:
    idata = pm.sample(
        draws=1000,  # number of samples to keep
        tune=1000,  # number of burn-in samples, which are discarded
        chains=3,
        random_seed=0,
    )
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (3 chains in 2 jobs)
NUTS: [p]

Sampling 3 chains for 1_000 tune and 1_000 draw iterations (3_000 + 3_000 draws total) took 2 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics
# Show the results of each chain.
az.plot_trace(idata)
# Adjust subplot spacing so the panels do not overlap.
plt.tight_layout()
../../_images/1b01480a8e416a3096b7cd0bba500167de94d5304129e968b43e710fe2c3bf2e.png
# Plot the posterior distribution from the sampled inference data.
az.plot_posterior(idata)
<Axes: title={'center': 'p'}>
../../_images/58e35c084b5d821bd2ed53d196e88cd3ae53c7236ab6b6c8fe840c19f5f442f4.png
# Tabulate summary statistics and diagnostics (mean, sd, HDI, ESS, r_hat) of the samples.
az.summary(idata)
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
p 0.425 0.175 0.11 0.739 0.006 0.002 1008.0 1639.0 1.0