Explainable Boosting Machine#
Explainable Boosting Machine (EBM) はtree-basedの一般化加法モデル(generalized additive model: GAM)で、
交互作用の自動検知
解釈性の高さ
が特徴
EBM(Nori et al., 2019)はGA\(^2\)M(Lou et al., 2013)の実装を並列可能にして高速化したもの。
モデル#
通常のGAM
に交互作用項を追加した
というモデルを Generalized Additive Model plus Interactions (GA\(^2\)M) という。
なお、ここで\(g(\cdot)\)は回帰や分類の設定にGAMを適用するためのlink functionで、\(x_i\)は\(i\)番目の特徴量、\(f_i(\cdot)\)は特徴量ごとの関数(feature function)
GA\(^2\)MやEBMは伝統的なGAMからの進化点として
特徴関数(feature function)\(f_i\)にbaggingや勾配ブースティングを使用する
交互作用項 \(f_{i,j}(x_i,x_j)\) の自動検知をする、「FAST」というアルゴリズムを備える
という特徴がある。
学習アルゴリズム#
GAMの主要なアルゴリズムには
backfitting
gradient boosting
がある。\(GA^2M\)ではgradient boostingを採用し、勾配(残差など)を近似するように学習させる
Notation#
\(n\):特徴量の次元数
\(u \subseteq \{ 1, \dots, n \}\):特徴量のインデックスの集合
\(\boldsymbol{x} = (x_1, \dots, x_n)\):特徴量ベクトル
\(\boldsymbol{x}_u\):\(u\)に含まれるインデックスの特徴量だけの特徴量ベクトル
\(\mathcal{U}\):すべての特徴量とそのペアのインデックス集合
\(\mathcal{U}^1=\{\{i\} \mid 1 \leq i \leq n\}\)
\(\mathcal{U}^2=\{\{i, j\} \mid 1 \leq i<j \leq n\}\)
\(\mathcal{U}=\mathcal{U}^1 \cup \mathcal{U}^2\)
\(\mathcal{H}_u\):ルベーグ可測な関数\(f_u(x_u)\)のヒルベルト空間
\(\mathcal{H}^1=\sum_{u \in \mathcal{U}^1} \mathcal{H}_u\): 単変量成分に対して加法形 \(F(\boldsymbol{x})=\sum_{u \in \mathcal{U}^1} f_u\left(x_u\right)\) をもつ関数のヒルベルト空間
なお、これらの成分 \(f_u\) を shape functions と呼ぶ。
Algorithm 1 はシンプルでわかりやすいものの高次元データに対してやや効率が悪いアルゴリズム
\(\mathcal{S} \leftarrow \emptyset\) (\(\mathcal{S}\):選択した特徴量のペアの集合)
\(\mathcal{Z} \leftarrow \mathcal{U}^2\) (\(\mathcal{Z}\):残りの特徴量のペア)
while not converge do
\(\displaystyle F \leftarrow \operatorname*{arg min}_{F \in \mathcal{H}_1 + \sum_{u \in \mathcal{S}}\mathcal{H}_u} \frac{1}{2}\mathbb{E}\left[(y - F(\boldsymbol{x}))^2\right]\)
(最適な\(F \in \mathcal{H}_1 + \sum_{u \in \mathcal{S}}\mathcal{H}_u\)を選択)\(R \leftarrow y - F(\boldsymbol{x})\)
(残差\(R\)を計算)for all \(u \in \mathcal{Z}\) do
\(F_u \leftarrow \mathbb{E}[R \mid x_u]\)
(すべての候補について交互作用項を計算)
\(u^* \leftarrow \arg\min_{u \in \mathcal{Z}} \frac{1}{2}\mathbb{E}\left[(R - F_u(x_u))^2\right]\)
\(\mathcal{S} \leftarrow \mathcal{S} \cup \{u^*\}\) (最適な交互作用項のペアを追加)
\(\mathcal{Z} \leftarrow \mathcal{Z} - \{u^*\}\) (候補集合から削除)
Algorithm 1には2つの問題がある
すべての交互作用のペアは \(O(n^2)\) あるので高コスト
ペアを1つ追加するたびにモデル全体を再学習する必要がある
これを近似する
FAST#
交互作用項\(f_{ij}(x_i,x_j)\)を推定するのは計算コストが高いため、きわめてシンプルなtreeモデル\(T_{ij}\)へと簡略化する。
まず、2つの変数\(x_i,x_j\)のそれぞれの軸に平行なcutoff点 \(c_i, c_j\)をつくり、\(x_i, x_j\)の2次元の特徴空間を4象限に区切る。
各象限内のデータ点の平均をとって予測値とする
すべての可能な\((c_i, c_j)\)の点について残差平方和 \(RSS\)が最小化される最適な\(T_{ij}\)を選択する
変数\(x_i\)がとりうる値のソート済み集合を \(dom(x_i) = {v_1,...,v_{d_i}}\) とする。ここで\(d_i = |dom(x_i)|\)。
\(x_i=v\)のときの目的変数の総和を\(H_i^t(v)\)を、重み(あるいはサンプル数)の総和を\(H_i^w(v)\)とおく。直感的には、これらは回帰木におけるヒストグラムである。同様に \(H_{ij}^t(u,v)\)および\(H_{ij}^w(u,v)\) を、\((x_i,x_j)=(u,v)\) のときの目的変数の総和および重みの総和として定義する。
また、累積ヒストグラムを \(C H_i^t(v)\) と \(CH_i^w(v)\)とおく。
\(C H_i^t(v)=\sum_{u \leq v} H_i^t(u)\)
\(C H_i^w(v)=\sum_{u \leq v} H_i^w(u)\)
さらに、次のように定義する。
\(\overline{C H_i^t}(v)=\sum_{u>v} H_i^t(u) = CH_i^t(v_i^{d_i})-CH_i^t(v)\)
\(\overline{C H_i^w}(v)=\sum_{u>v} H_i^w(u) = CH_i^w(v_i^{d_i})-CH_i^w(v)\)
FAST#
交互作用項は、特徴量の数を\(n\)とすると \(O(n^2)\) の組み合わせ・計算量になり、高次元データでは計算コストが高くなる。
FASTは \(O(N+b^2)\)、\(N\)はサンプルサイズで\(b\)は特徴量のbinの数(例えば\(b=256\))にする。
Feature Importance#
モデルのfeature function \(f_i\) は各特徴量が予測にどれくらい寄与したかが明確にわかる
実装#
InterpretMLパッケージに実装されている
Explainable Boosting Machine — InterpretML documentation
import pandas as pd
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_absolute_error
from interpret.glassbox import ExplainableBoostingRegressor
from interpret import show
# 1) データロード
data = fetch_california_housing(as_frame=True)
X: pd.DataFrame = data.data
y: pd.Series = data.target # Median house value
# 2) train/test split
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42
)
# 3) EBM(回帰)
ebm = ExplainableBoostingRegressor(
random_state=42,
# interactions=10, # 交互作用を入れたいなら(解釈はやや複雑に)
# max_bins=256, # 連続特徴のビン数(大きいほど柔軟だが過学習注意)
)
# 4) fit & 評価
ebm.fit(X_train, y_train)
pred = ebm.predict(X_test)
print(f"R^2: {r2_score(y_test, pred):.3f}")
print(f"MAE: {mean_absolute_error(y_test, pred):.3f}")
# 5) Explain(global / local)
global_exp = ebm.explain_global(name="EBM global (California Housing)")
local_exp = ebm.explain_local(X_test.iloc[:5], y_test.iloc[:5], name="EBM local (first 5)")
# Notebook環境なら可視化(ブラウザ/埋め込み表示)
show(global_exp)
show(local_exp)
---------------------------------------------------------------------------
KeyboardInterrupt Traceback (most recent call last)
Cell In[1], line 26
19 ebm = ExplainableBoostingRegressor(
20 random_state=42,
21 # interactions=10, # 交互作用を入れたいなら(解釈はやや複雑に)
22 # max_bins=256, # 連続特徴のビン数(大きいほど柔軟だが過学習注意)
23 )
25 # 4) fit & 評価
---> 26 ebm.fit(X_train, y_train)
28 pred = ebm.predict(X_test)
29 print(f"R^2: {r2_score(y_test, pred):.3f}")
File ~/work/notes/notes/.venv/lib/python3.10/site-packages/interpret/glassbox/_ebm/_ebm.py:1179, in EBMModel.fit(self, X, y, sample_weight, bags, init_score)
1176 stop_flag = None
1177 shm_name = None
-> 1179 results = parallel(
1180 delayed(booster)(
1181 shm_name=shm_name,
1182 bag_idx=idx,
1183 callback=callback,
1184 dataset=(
1185 shared.name if shared.name is not None else shared.dataset
1186 ),
1187 intercept_rounds=n_intercept_rounds,
1188 intercept_learning_rate=develop.get_option(
1189 "intercept_learning_rate"
1190 ),
1191 intercept=bagged_intercept[idx],
1192 bag=internal_bags[idx],
1193 # TODO: instead of making these copies we should
1194 # put init_score into the native shared dataframe
1195 init_scores=(
1196 init_score
1197 if (
1198 init_score is None
1199 or internal_bags[idx] is None
1200 or np.count_nonzero(internal_bags[idx])
1201 == len(internal_bags[idx])
1202 )
1203 else init_score[internal_bags[idx] != 0]
1204 ),
1205 term_features=term_features,
1206 smoothing_rounds=smoothing_rounds,
1207 # if there are no validation samples, turn off early stopping
1208 # because the validation metric cannot improve each round
1209 early_stopping_rounds=(
1210 early_stopping_rounds
1211 if (
1212 internal_bags[idx] is not None
1213 and (internal_bags[idx] < 0).any()
1214 )
1215 else 0
1216 ),
1217 rng=rngs[idx],
1218 )
1219 for idx in range(self.outer_bags)
1220 )
1222 best_iteration = [[]]
1223 models = []
File ~/work/notes/notes/.venv/lib/python3.10/site-packages/joblib/parallel.py:2072, in Parallel.__call__(self, iterable)
2066 # The first item from the output is blank, but it makes the interpreter
2067 # progress until it enters the Try/Except block of the generator and
2068 # reaches the first `yield` statement. This starts the asynchronous
2069 # dispatch of the tasks to the workers.
2070 next(output)
-> 2072 return output if self.return_generator else list(output)
File ~/work/notes/notes/.venv/lib/python3.10/site-packages/joblib/parallel.py:1682, in Parallel._get_outputs(self, iterator, pre_dispatch)
1679 yield
1681 with self._backend.retrieval_context():
-> 1682 yield from self._retrieve()
1684 except GeneratorExit:
1685 # The generator has been garbage collected before being fully
1686 # consumed. This aborts the remaining tasks if possible and warn
1687 # the user if necessary.
1688 self._exception = True
File ~/work/notes/notes/.venv/lib/python3.10/site-packages/joblib/parallel.py:1800, in Parallel._retrieve(self)
1789 if self.return_ordered:
1790 # Case ordered: wait for completion (or error) of the next job
1791 # that have been dispatched and not retrieved yet. If no job
(...)
1795 # control only have to be done on the amount of time the next
1796 # dispatched job is pending.
1797 if (nb_jobs == 0) or (
1798 self._jobs[0].get_status(timeout=self.timeout) == TASK_PENDING
1799 ):
-> 1800 time.sleep(0.01)
1801 continue
1803 elif nb_jobs == 0:
1804 # Case unordered: jobs are added to the list of jobs to
1805 # retrieve `self._jobs` only once completed or in error, which
(...)
1811 # timeouts before any other dispatched job has completed and
1812 # been added to `self._jobs` to be retrieved.
KeyboardInterrupt:
統計的推論#
[2601.18857v1] Statistical Inference for Explainable Boosting Machines