ベイズ線形回帰#

import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
# データを作成
n = 1000

from scipy.stats import multivariate_normal
mean = np.array([3, 5])
Sigma = np.array([
    [1, 0.5],
    [0.5, 2],
])
X = multivariate_normal.rvs(mean=mean, cov=Sigma, size=n, random_state=0)

import statsmodels.api as sm
X = sm.add_constant(X)

# 真のパラメータ
beta = np.array([2, 3, 4])

データが均一分散の場合#

# 均一分散の場合
e = np.random.normal(loc=0, scale=1, size=n)
y = X @ beta + e
# 頻度主義
import statsmodels.api as sm
ols = sm.OLS(y, X).fit(cov_type="HC1")
ols.summary()
OLS Regression Results
Dep. Variable: y R-squared: 0.980
Model: OLS Adj. R-squared: 0.980
Method: Least Squares F-statistic: 2.308e+04
Date: Fri, 29 Nov 2024 Prob (F-statistic): 0.00
Time: 04:53:35 Log-Likelihood: -1424.0
No. Observations: 1000 AIC: 2854.
Df Residuals: 997 BIC: 2869.
Df Model: 2
Covariance Type: HC1
coef std err z P>|z| [0.025 0.975]
const 1.8841 0.139 13.528 0.000 1.611 2.157
x1 2.9957 0.035 84.932 0.000 2.927 3.065
x2 4.0267 0.025 161.028 0.000 3.978 4.076
Omnibus: 1.129 Durbin-Watson: 2.011
Prob(Omnibus): 0.569 Jarque-Bera (JB): 1.196
Skew: 0.076 Prob(JB): 0.550
Kurtosis: 2.926 Cond. No. 25.9


Notes:
[1] Standard Errors are heteroscedasticity robust (HC1)
import pymc as pm
import arviz as az

model = pm.Model()
with model:
    beta0 = pm.Normal("beta0", mu=0, sigma=1)
    beta1 = pm.Normal("beta1", mu=0, sigma=1)
    beta2 = pm.Normal("beta2", mu=0, sigma=1)
    sigma = pm.HalfNormal("sigma", sigma=1)  # 分散なので非負の分布を使う

    # 平均値 mu
    mu = beta0 + beta1 * X[:, 1] + beta2 * X[:, 2]
    # 観測値をもつ確率変数は_obsとする
    y_obs = pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y)

# モデルをGraphvizで表示
pm.model_to_graphviz(model)
WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.
../../_images/1ac456596cfbbf6f91b5660732d86b6542d5151e74bb110d1d57aae65e4f91df.svg
# ベイズ線形回帰モデルをサンプリング
with model:
    idata = pm.sample(
        chains=2,
        tune=1000, # バーンイン期間の、捨てるサンプル数
        draws=2000, # 採用するサンプル数
        random_seed=0,
    )

# 各chainsの結果を表示
az.plot_trace(idata, figsize=[4, 4])
plt.tight_layout()
plt.show()
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [beta0, beta1, beta2, sigma]

Sampling 2 chains for 1_000 tune and 2_000 draw iterations (2_000 + 4_000 draws total) took 8 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics
../../_images/5a2e3e4486042e812fde5dc6284d59bdf84757df46492684157b7551d7237f60.png
az.plot_posterior(idata)
plt.show()
../../_images/9a20e1d5857fca31347e686c6a13965e7a819424c638c2f7805ae924df038110.png

データが不均一分散の場合#

# 不均一分散の場合
def normalize(x):
    return (x - x.min()) / (x.max() - x.min())

sigma = 1 + normalize(X[:, 1] + X[:, 2]) * 3
e = np.random.normal(loc=0, scale=sigma, size=n)
y = X @ beta + e

頻度主義 & 不均一分散に頑健な誤差推定#

# 頻度主義
import statsmodels.api as sm
ols = sm.OLS(y, X).fit(cov_type="HC1")
ols.summary()
OLS Regression Results
Dep. Variable: y R-squared: 0.888
Model: OLS Adj. R-squared: 0.888
Method: Least Squares F-statistic: 4000.
Date: Fri, 29 Nov 2024 Prob (F-statistic): 0.00
Time: 04:54:10 Log-Likelihood: -2347.6
No. Observations: 1000 AIC: 4701.
Df Residuals: 997 BIC: 4716.
Df Model: 2
Covariance Type: HC1
coef std err z P>|z| [0.025 0.975]
const 1.7957 0.306 5.878 0.000 1.197 2.395
x1 3.0517 0.089 34.286 0.000 2.877 3.226
x2 4.0183 0.064 62.614 0.000 3.893 4.144
Omnibus: 9.444 Durbin-Watson: 2.183
Prob(Omnibus): 0.009 Jarque-Bera (JB): 10.945
Skew: 0.151 Prob(JB): 0.00420
Kurtosis: 3.414 Cond. No. 25.9


Notes:
[1] Standard Errors are heteroscedasticity robust (HC1)

↑ 切片の推定にバイアスが入っている

均一分散を想定したベイズ線形回帰#

import pymc as pm
import arviz as az

model = pm.Model()
with model:
    beta0 = pm.Normal("beta0", mu=0, sigma=1)
    beta1 = pm.Normal("beta1", mu=0, sigma=1)
    beta2 = pm.Normal("beta2", mu=0, sigma=1)
    sigma = pm.HalfNormal("sigma", sigma=1)  # 分散なので非負の分布を使う

    # 平均値 mu
    mu = beta0 + beta1 * X[:, 1] + beta2 * X[:, 2]
    # 観測値をもつ確率変数は_obsとする
    y_obs = pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y)

# モデルをGraphvizで表示
pm.model_to_graphviz(model)
../../_images/1ac456596cfbbf6f91b5660732d86b6542d5151e74bb110d1d57aae65e4f91df.svg
# ベイズ線形回帰モデルをサンプリング
with model:
    idata = pm.sample(
        chains=2,
        tune=1000, # バーンイン期間の、捨てるサンプル数
        draws=2000, # 採用するサンプル数
        random_seed=0,
    )

# 各chainsの結果を表示
az.plot_trace(idata, figsize=[4, 4])
plt.tight_layout()
plt.show()
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [beta0, beta1, beta2, sigma]

Sampling 2 chains for 1_000 tune and 2_000 draw iterations (2_000 + 4_000 draws total) took 8 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics
../../_images/6b02ef902bb3ec1c95d6af321865695198f8c6d77839db2aca26c4ba7e7d04e1.png
az.plot_posterior(idata)
plt.show()
../../_images/c40d63e39b54798feaeee81d6e98af1b6607042d5e4bbd4744fa2aec61bae144.png

不均一分散を想定したベイズ線形回帰(WIP)#

分散をxの関数にしたかった。以下コードで推定できるが個々の\(\sigma_i\)が別々に推定される形になって結果が見づらい。もっといい表し方はないものか。

import pymc as pm
import arviz as az

model = pm.Model()
with model:
    beta0 = pm.Normal("beta0", mu=0, sigma=1)
    beta1 = pm.Normal("beta1", mu=0, sigma=1)
    beta2 = pm.Normal("beta2", mu=0, sigma=1)

    # 誤差分散にも線形モデルを入れる
    w0 = pm.Normal("w0", mu=0, sigma=1)
    w1 = pm.Normal("w1", mu=0, sigma=1)
    w2 = pm.Normal("w2", mu=0, sigma=1)
    lam = pm.math.exp(w0 + w1 * X[:, 1] + w2 * X[:, 2])
    sigma = pm.Exponential("sigma", lam=lam)  # 分散なので非負の分布を使う

    # 平均値 mu
    mu = beta0 + beta1 * X[:, 1] + beta2 * X[:, 2]
    # 観測値をもつ確率変数は_obsとする
    y_obs = pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y)

# モデルをGraphvizで表示
pm.model_to_graphviz(model)

# ベイズ線形回帰モデルをサンプリング
with model:
    idata = pm.sample(
        chains=2,
        tune=1000, # バーンイン期間の、捨てるサンプル数
        draws=2000, # 採用するサンプル数
        random_seed=0,
    )

# 各chainsの結果を表示
az.plot_trace(idata, figsize=[4, 4])
plt.tight_layout()
plt.show()

az.plot_posterior(idata)
plt.show()