分位点回帰#
分位点
と表す。ここで
例えば
標準的な回帰モデルは二乗誤差
分位点回帰 (quantile regression)モデルはpinball loss
pinball lossは
あるいは
と書かれる
Show code cell source
import numpy as np
import matplotlib.pyplot as plt
def pinball_loss(x, tau):
return (tau - 1 * (x <= 0)) * x
x = np.linspace(-1, 1, 100)
fig, axes = plt.subplots(figsize=[10, 2], ncols=3)
for i, tau in enumerate([0.1, 0.5, 0.9]):
y = pinball_loss(x, tau=tau)
axes[i].plot(x, y)
if i == 0:
axes[i].set(title=f"τ={tau}", xlabel=r"$x$", ylabel=r"$y = (\tau - 1(x <= 0)) x$")
else:
axes[i].set(title=f"τ={tau}", xlabel=r"$x$")
fig.show()

なお、pinball lossは
と、絶対誤差と比例する形になる。
絶対誤差の和を目的関数にとった線形モデルは統計学においてleast absolute deviations (LAD) と呼ばれ、その解は条件付き中央値になる
Show code cell source
import numpy as np
import matplotlib.pyplot as plt
def pinball_loss(x, tau):
return (tau - 1 * (x <= 0)) * x
x = np.linspace(-3, 3, 100)
fig, ax = plt.subplots()
ax.plot(x, pinball_loss(x, tau=0.5), label=r"$\rho_{0.5}(x)$")
ax.plot(x, abs(x), label="|x|")
ax.legend()
fig.show()

絶対誤差の最適解
誤差関数
絶対値の中身の符号で場合分けすると
予測損失を微分するとそれぞれの項は
よって
となる。
となる点が予測損失を極小化することがわかる。これは
である。なお、これは累積分布関数
なお、中央値の定義には、以下の式を満たす
というものもある。
定積分の定義
より、この導関数は
Pinball lossの最適解
Pinball Lossを少し表現を変えて
と表すと、さきほどの絶対誤差の場合分けした項に
絶対誤差の場合と同様に導関数は
となる。ここで累積分布関数
とおき、導関数を0とおく。
これを整理すると
となり、累積分布
となる。累積分布関数の逆関数は分位点なので、
モデルの評価#
D2 pinball score#
ここで
この
を代入したものが
interval score#
分位点回帰モデルの実践#
statsmodelsでは quantreg()
で実行できる
Quantile regression - statsmodels 0.15.0 (+213)
Show code cell source
import numpy as np
import pandas as pd
import statsmodels.api as sm
import statsmodels.formula.api as smf
import matplotlib.pyplot as plt
data = sm.datasets.engel.load_pandas().data
fig, ax = plt.subplots()
ax.scatter(data["income"], data["foodexp"])
ax.set(xlabel="income", ylabel="foodexp", title="Quantile Linear Regression")
x = np.linspace(data["income"].min(), data["income"].max(), 10)
model = smf.quantreg("foodexp ~ income", data)
for q in [0.1, 0.5, 0.9]:
res = model.fit(q=q)
y_hat = res.predict(pd.DataFrame({"income": x}))
ax.plot(x, y_hat, label=fr"$\tau = {q}$")
ax.legend()
fig.show()
