Bayesian Inference (BI)パッケージ

Bayesian Inference (BI)パッケージ#

Bayesian Inference (BI) は Numpyro等のPythonパッケージの上に作られた、使いやすいインターフェースのパッケージ。R版、Julia版もある。

GPU対応しておりStanより早いのが売り

Bayesian Inference (BI)

The Bayesian Inference library for Python R and Julia | bioRxiv

  • 紹介論文。CPUでもStanより速いという実験結果がのっている

単回帰モデルの例#

Univariate Linear Regression – Bayesian Inference (BI)

from BI import bi

# Setup device------------------------------------------------
m = bi(platform='cpu')

# Import Data & Data Manipulation ------------------------------------------------
# Import
from importlib.resources import files
data_path = m.load.howell1(only_path = True)
m.data(data_path, sep=';') 

m.df = m.df[m.df.age > 18] # Subset data to adults
m.scale(['weight']) # Normalize
jax.local_device_count 4
height weight age male
0 151.765 0.430669 63.0 1
1 139.700 -1.326018 63.0 0
2 136.525 -2.041868 65.0 0
3 156.845 1.238745 41.0 1
4 145.415 -0.583818 51.0 0
... ... ... ... ...
534 162.560 0.307701 27.0 0
537 142.875 -1.672963 31.0 0
540 162.560 1.102602 31.0 1
541 156.210 1.396847 21.0 0
543 158.750 1.159694 68.0 1

346 rows × 4 columns

# Define model ------------------------------------------------
def model(weight, height):    
    a = m.dist.normal(178, 20, name = 'a') 
    b = m.dist.log_normal(0, 1, name = 'b') 
    s = m.dist.uniform(0, 50, name = 's') 
    m.dist.normal(a + b * weight , s, obs = height) 

# Run mcmc ------------------------------------------------
m.fit(model)  # Optimize model parameters through MCMC sampling
# Summary ------------------------------------------------
m.summary() # Get posterior distributions
arviz - WARNING - Shape validation failed: input_shape: (1, 500), minimum_shape: (chains=2, draws=4)
mean sd hdi_5.5% hdi_94.5% mcse_mean mcse_sd ess_bulk ess_tail r_hat
a 154.65 0.28 154.19 155.05 0.01 0.01 465.66 407.28 NaN
b 5.78 0.30 5.28 6.22 0.02 0.01 329.40 314.86 NaN
s 5.17 0.20 4.87 5.48 0.01 0.01 415.27 295.66 NaN