NGBoost

Case 1: Linear data with heteroscedastic variance

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

x = np.linspace(0, 10, 1000)
sigma = np.sqrt(x)
y = norm.rvs(loc=x, scale=sigma, random_state=0)
X = x.reshape(-1, 1)

fig, ax = plt.subplots()
ax.scatter(x, y)
ax.plot(x, x, color="black", alpha=.5, label="mean")
ax.set(xlabel="x", ylabel="y")
ax.legend()
fig.show()
[Figure: scatter of the simulated data with the true mean y = x overlaid]
from ngboost import NGBRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

ngb = NGBRegressor().fit(X_train, y_train)  # defaults: Normal distribution, log score
y_pred = ngb.predict(X_test)                # point prediction (mean of the predicted Normal)
y_dist = ngb.pred_dist(X_test)              # full predictive distribution per test point

print('Test MSE', mean_squared_error(y_test, y_pred))

# test Negative Log Likelihood
test_NLL = -y_dist.logpdf(y_test).mean()
print('Test NLL', test_NLL)
[iter 0] loss=2.7161 val_loss=0.0000 scale=1.0000 norm=3.0752
[iter 100] loss=2.1704 val_loss=0.0000 scale=2.0000 norm=3.5366
[iter 200] loss=1.9829 val_loss=0.0000 scale=2.0000 norm=3.3186
[iter 300] loss=1.9095 val_loss=0.0000 scale=2.0000 norm=3.2352
[iter 400] loss=1.8695 val_loss=0.0000 scale=2.0000 norm=3.1663
Test MSE 4.561440687111083
Test NLL 2.26809761464561
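
The NLL above is computed from the Normal distributions returned by pred_dist. As a sanity check, the same quantity can be recomputed by hand with scipy.stats.norm from the fitted loc and scale parameters; the two values should agree (a minimal sketch, reusing y_dist and y_test from the cell above).

# recompute the test NLL directly from the fitted Normal parameters
manual_NLL = -norm.logpdf(y_test, loc=y_dist.params["loc"], scale=y_dist.params["scale"]).mean()
print('Manual test NLL', manual_NLL)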
fig, ax = plt.subplots()
ax.scatter(x, y, alpha=.5)
ax.plot(x, x, color="black", alpha=.5, label="mean")
ax.set(xlabel="x", ylabel="y")
ax.legend()

# sort X_test so the interval curves are plotted left to right
X_test = np.sort(X_test, axis=0)
y_dist = ngb.pred_dist(X_test)

alphas = [0.05, 0.01]
colors = ["darkorange", "tomato"]
for alpha, color in zip(alphas, colors):
    upper = norm.ppf(q=1 - (alpha/2), loc=y_dist.params["loc"], scale=y_dist.params["scale"])
    lower = norm.ppf(q=(alpha/2), loc=y_dist.params["loc"], scale=y_dist.params["scale"])
    ax.plot(X_test[:, 0], upper, alpha=.9, color=color, linestyle="--", label=rf"$\alpha$={alpha}")
    ax.plot(X_test[:, 0], lower, alpha=.9, color=color, linestyle="--")

ax.legend()
fig.show()
[Figure: Case 1 data with NGBoost prediction intervals for α = 0.05 and α = 0.01]
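
Because the data were generated with σ = √x, the fitted scale parameter can be compared against the true standard deviation to see how well NGBoost recovers the heteroscedastic noise (a minimal sketch, reusing the sorted X_test and the corresponding y_dist from the cell above).

fig, ax = plt.subplots()
ax.plot(X_test[:, 0], np.sqrt(X_test[:, 0]), color="black", label=r"true $\sigma = \sqrt{x}$")
ax.plot(X_test[:, 0], y_dist.params["scale"], color="tomato", linestyle="--", label="predicted scale")
ax.set(xlabel="x", ylabel="standard deviation")
ax.legend()
fig.show()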

Case 2: Nonlinear data with heteroscedastic variance

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

x = np.linspace(0, 5, 1000)
sigma = (np.sin(x) + 2) * 5
z = 10 + x + x ** 2
y = norm.rvs(loc=z, scale=sigma, random_state=0)
X = x.reshape(-1, 1)

fig, ax = plt.subplots()
ax.scatter(x, y)
ax.plot(x, z, color="black", alpha=.5, label="mean")
ax.set(xlabel="x", ylabel="y")
ax.legend()
fig.show()
[Figure: scatter of the simulated data with the true mean z = 10 + x + x² overlaid]
from ngboost import NGBRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

ngb = NGBRegressor().fit(X_train, y_train)
y_pred = ngb.predict(X_test)
y_dist = ngb.pred_dist(X_test)

print('Test MSE', mean_squared_error(y_test, y_pred))

# test Negative Log Likelihood
test_NLL = -y_dist.logpdf(y_test).mean()
print('Test NLL', test_NLL)
[iter 0] loss=4.0892 val_loss=0.0000 scale=2.0000 norm=23.7456
[iter 100] loss=3.7119 val_loss=0.0000 scale=2.0000 norm=16.5191
[iter 200] loss=3.5947 val_loss=0.0000 scale=2.0000 norm=15.8042
[iter 300] loss=3.5436 val_loss=0.0000 scale=2.0000 norm=15.3784
[iter 400] loss=3.5081 val_loss=0.0000 scale=2.0000 norm=15.0400
Test MSE 134.73786073450074
Test NLL 3.9583271061912804
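
The default NGBRegressor uses a Normal distribution with the log score and 500 boosting iterations. For the more strongly nonlinear Case 2, it can help to set these explicitly and hold out a validation set for early stopping; the sketch below is illustrative, with untuned hyperparameter values.

from ngboost.distns import Normal
from ngboost.scores import LogScore

X_tr, X_val, y_tr, y_val = train_test_split(X_train, y_train, test_size=0.2)

ngb_es = NGBRegressor(Dist=Normal, Score=LogScore, n_estimators=1000, learning_rate=0.01, verbose_eval=100)
ngb_es.fit(X_tr, y_tr, X_val=X_val, Y_val=y_val, early_stopping_rounds=50)

print('Test MSE', mean_squared_error(y_test, ngb_es.predict(X_test)))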
fig, ax = plt.subplots()
ax.scatter(x, y, alpha=.5)
ax.plot(x, z, color="black", alpha=.5, label="mean")
ax.set(xlabel="x", ylabel="y")
ax.legend()

X_test = np.sort(X_test, axis=0)
y_dist = ngb.pred_dist(X_test)

alphas = [0.05, 0.01]
colors = ["darkorange", "tomato"]
for alpha, color in zip(alphas, colors):
    upper = norm.ppf(q=1 - (alpha/2), loc=y_dist.params["loc"], scale=y_dist.params["scale"])
    lower = norm.ppf(q=(alpha/2), loc=y_dist.params["loc"], scale=y_dist.params["scale"])
    ax.plot(X_test[:, 0], upper, alpha=.9, color=color, linestyle="--", label=rf"$\alpha$={alpha}")
    ax.plot(X_test[:, 0], lower, alpha=.9, color=color, linestyle="--")

ax.legend()
fig.show()
[Figure: Case 2 data with NGBoost prediction intervals for α = 0.05 and α = 0.01]
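
As in Case 1, the predicted scale can be checked against the true noise level, here σ = (sin(x) + 2) · 5 (a minimal sketch, reusing the sorted X_test and the corresponding y_dist from the cell above).

fig, ax = plt.subplots()
ax.plot(X_test[:, 0], (np.sin(X_test[:, 0]) + 2) * 5, color="black", label=r"true $\sigma$")
ax.plot(X_test[:, 0], y_dist.params["scale"], color="tomato", linestyle="--", label="predicted scale")
ax.set(xlabel="x", ylabel="standard deviation")
ax.legend()
fig.show()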