Bayesian Inference (BI)パッケージ#
Bayesian Inference (BI) は Numpyro等のPythonパッケージの上に作られた、使いやすいインターフェースのパッケージ。R版、Julia版もある。
GPU対応しておりStanより早いのが売り
The Bayesian Inference library for Python R and Julia | bioRxiv
紹介論文。CPUでもStanより速いという実験結果がのっている
単回帰モデルの例#
Univariate Linear Regression – Bayesian Inference (BI)
from BI import bi
# Setup device------------------------------------------------
m = bi(platform='cpu')
# Import Data & Data Manipulation ------------------------------------------------
# Import
from importlib.resources import files
data_path = m.load.howell1(only_path = True)
m.data(data_path, sep=';')
m.df = m.df[m.df.age > 18] # Subset data to adults
m.scale(['weight']) # Normalize
jax.local_device_count 4
| height | weight | age | male | |
|---|---|---|---|---|
| 0 | 151.765 | 0.430669 | 63.0 | 1 |
| 1 | 139.700 | -1.326018 | 63.0 | 0 |
| 2 | 136.525 | -2.041868 | 65.0 | 0 |
| 3 | 156.845 | 1.238745 | 41.0 | 1 |
| 4 | 145.415 | -0.583818 | 51.0 | 0 |
| ... | ... | ... | ... | ... |
| 534 | 162.560 | 0.307701 | 27.0 | 0 |
| 537 | 142.875 | -1.672963 | 31.0 | 0 |
| 540 | 162.560 | 1.102602 | 31.0 | 1 |
| 541 | 156.210 | 1.396847 | 21.0 | 0 |
| 543 | 158.750 | 1.159694 | 68.0 | 1 |
346 rows × 4 columns
# Define model ------------------------------------------------
def model(weight, height):
a = m.dist.normal(178, 20, name = 'a')
b = m.dist.log_normal(0, 1, name = 'b')
s = m.dist.uniform(0, 50, name = 's')
m.dist.normal(a + b * weight , s, obs = height)
# Run mcmc ------------------------------------------------
m.fit(model) # Optimize model parameters through MCMC sampling
# Summary ------------------------------------------------
m.summary() # Get posterior distributions
arviz - WARNING - Shape validation failed: input_shape: (1, 500), minimum_shape: (chains=2, draws=4)
| mean | sd | hdi_5.5% | hdi_94.5% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat | |
|---|---|---|---|---|---|---|---|---|---|
| a | 154.65 | 0.28 | 154.19 | 155.05 | 0.01 | 0.01 | 465.66 | 407.28 | NaN |
| b | 5.78 | 0.30 | 5.28 | 6.22 | 0.02 | 0.01 | 329.40 | 314.86 | NaN |
| s | 5.17 | 0.20 | 4.87 | 5.48 | 0.01 | 0.01 | 415.27 | 295.66 | NaN |