Explainable Boosting Machine#

Explainable Boosting Machine (EBM) はtree-basedの一般化加法モデル(generalized additive model: GAM)で、

  • 交互作用の自動検知

  • 解釈性の高さ

が特徴

EBM(Nori et al., 2019)はGA\(^2\)M(Lou et al., 2013)の実装を並列可能にして高速化したもの。

モデル#

通常のGAM

\[ g(E[y])= \beta_0 + \sum f_i(x_i) \]

に交互作用項を追加した

\[ g(E[y])=\beta_0+\sum f_i(x_i)+\sum f_{i, j}(x_i, x_j) \]

というモデルを Generalized Additive Model plus Interactions (GA\(^2\)M) という。

なお、ここで\(g(\cdot)\)は回帰や分類の設定にGAMを適用するためのlink functionで、\(x_i\)\(i\)番目の特徴量、\(f_i(\cdot)\)は特徴量ごとの関数(feature function)

GA\(^2\)MやEBMは伝統的なGAMからの進化点として

  1. 特徴関数(feature function)\(f_i\)にbaggingや勾配ブースティングを使用する

  2. 交互作用項 \(f_{i,j}(x_i,x_j)\) の自動検知をする、「FAST」というアルゴリズムを備える

という特徴がある。

学習アルゴリズム#

GAMの主要なアルゴリズムには

  1. backfitting

  2. gradient boosting

がある。\(GA^2M\)ではgradient boostingを採用し、勾配(残差など)を近似するように学習させる

Notation#

  • \(n\):特徴量の次元数

  • \(u \subseteq \{ 1, \dots, n \}\):特徴量のインデックスの集合

    • \(\boldsymbol{x} = (x_1, \dots, x_n)\):特徴量ベクトル

    • \(\boldsymbol{x}_u\)\(u\)に含まれるインデックスの特徴量だけの特徴量ベクトル

  • \(\mathcal{U}\):すべての特徴量とそのペアのインデックス集合

    • \(\mathcal{U}^1=\{\{i\} \mid 1 \leq i \leq n\}\)

    • \(\mathcal{U}^2=\{\{i, j\} \mid 1 \leq i<j \leq n\}\)

    • \(\mathcal{U}=\mathcal{U}^1 \cup \mathcal{U}^2\)

  • \(\mathcal{H}_u\):ルベーグ可測な関数\(f_u(x_u)\)のヒルベルト空間

    • \(\mathcal{H}^1=\sum_{u \in \mathcal{U}^1} \mathcal{H}_u\): 単変量成分に対して加法形 \(F(\boldsymbol{x})=\sum_{u \in \mathcal{U}^1} f_u\left(x_u\right)\) をもつ関数のヒルベルト空間

    • なお、これらの成分 \(f_u\)shape functions と呼ぶ。

Algorithm 1 はシンプルでわかりやすいものの高次元データに対してやや効率が悪いアルゴリズム

Algorithm 1 GA\(^2\)M Framework
  1. \(\mathcal{S} \leftarrow \emptyset\)\(\mathcal{S}\):選択した特徴量のペアの集合)

  2. \(\mathcal{Z} \leftarrow \mathcal{U}^2\)\(\mathcal{Z}\):残りの特徴量のペア)

  3. while not converge do

    1. \(\displaystyle F \leftarrow \operatorname*{arg min}_{F \in \mathcal{H}_1 + \sum_{u \in \mathcal{S}}\mathcal{H}_u} \frac{1}{2}\mathbb{E}\left[(y - F(\boldsymbol{x}))^2\right]\)
      (最適な\(F \in \mathcal{H}_1 + \sum_{u \in \mathcal{S}}\mathcal{H}_u\)を選択)

    2. \(R \leftarrow y - F(\boldsymbol{x})\)
      (残差\(R\)を計算)

    3. for all \(u \in \mathcal{Z}\) do

      1. \(F_u \leftarrow \mathbb{E}[R \mid x_u]\)
        (すべての候補について交互作用項を計算)

  4. \(u^* \leftarrow \arg\min_{u \in \mathcal{Z}} \frac{1}{2}\mathbb{E}\left[(R - F_u(x_u))^2\right]\)

  5. \(\mathcal{S} \leftarrow \mathcal{S} \cup \{u^*\}\) (最適な交互作用項のペアを追加)

  6. \(\mathcal{Z} \leftarrow \mathcal{Z} - \{u^*\}\) (候補集合から削除)

Algorithm 1には2つの問題がある

  1. すべての交互作用のペアは \(O(n^2)\) あるので高コスト

  2. ペアを1つ追加するたびにモデル全体を再学習する必要がある

これを近似する

FAST#

交互作用項\(f_{ij}(x_i,x_j)\)を推定するのは計算コストが高いため、きわめてシンプルなtreeモデル\(T_{ij}\)へと簡略化する。

  1. まず、2つの変数\(x_i,x_j\)のそれぞれの軸に平行なcutoff点 \(c_i, c_j\)をつくり、\(x_i, x_j\)の2次元の特徴空間を4象限に区切る。

  2. 各象限内のデータ点の平均をとって予測値とする

  3. すべての可能な\((c_i, c_j)\)の点について残差平方和 \(RSS\)が最小化される最適な\(T_{ij}\)を選択する

変数\(x_i\)がとりうる値のソート済み集合を \(dom(x_i) = {v_1,...,v_{d_i}}\) とする。ここで\(d_i = |dom(x_i)|\)

\(x_i=v\)のときの目的変数の総和を\(H_i^t(v)\)を、重み(あるいはサンプル数)の総和を\(H_i^w(v)\)とおく。直感的には、これらは回帰木におけるヒストグラムである。同様に \(H_{ij}^t(u,v)\)および\(H_{ij}^w(u,v)\) を、\((x_i,x_j)=(u,v)\) のときの目的変数の総和および重みの総和として定義する。

また、累積ヒストグラムを \(C H_i^t(v)\)\(CH_i^w(v)\)とおく。

  • \(C H_i^t(v)=\sum_{u \leq v} H_i^t(u)\)

  • \(C H_i^w(v)=\sum_{u \leq v} H_i^w(u)\)

さらに、次のように定義する。

  • \(\overline{C H_i^t}(v)=\sum_{u>v} H_i^t(u) = CH_i^t(v_i^{d_i})-CH_i^t(v)\)

  • \(\overline{C H_i^w}(v)=\sum_{u>v} H_i^w(u) = CH_i^w(v_i^{d_i})-CH_i^w(v)\)

FAST#

交互作用項は、特徴量の数を\(n\)とすると \(O(n^2)\) の組み合わせ・計算量になり、高次元データでは計算コストが高くなる。

FASTは \(O(N+b^2)\)\(N\)はサンプルサイズで\(b\)は特徴量のbinの数(例えば\(b=256\))にする。

Feature Importance#

モデルのfeature function \(f_i\) は各特徴量が予測にどれくらい寄与したかが明確にわかる

実装#

InterpretMLパッケージに実装されている

Explainable Boosting Machine — InterpretML documentation

import pandas as pd
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_absolute_error
from interpret.glassbox import ExplainableBoostingRegressor
from interpret import show

# 1) データロード
data = fetch_california_housing(as_frame=True)
X: pd.DataFrame = data.data
y: pd.Series = data.target  # Median house value

# 2) train/test split
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# 3) EBM(回帰)
ebm = ExplainableBoostingRegressor(
    random_state=42,
    # interactions=10,   # 交互作用を入れたいなら(解釈はやや複雑に)
    # max_bins=256,      # 連続特徴のビン数(大きいほど柔軟だが過学習注意)
)

# 4) fit & 評価
ebm.fit(X_train, y_train)

pred = ebm.predict(X_test)
print(f"R^2: {r2_score(y_test, pred):.3f}")
print(f"MAE: {mean_absolute_error(y_test, pred):.3f}")

# 5) Explain(global / local)
global_exp = ebm.explain_global(name="EBM global (California Housing)")
local_exp = ebm.explain_local(X_test.iloc[:5], y_test.iloc[:5], name="EBM local (first 5)")

# Notebook環境なら可視化(ブラウザ/埋め込み表示)
show(global_exp)
show(local_exp)
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Cell In[1], line 26
     19 ebm = ExplainableBoostingRegressor(
     20     random_state=42,
     21     # interactions=10,   # 交互作用を入れたいなら(解釈はやや複雑に)
     22     # max_bins=256,      # 連続特徴のビン数(大きいほど柔軟だが過学習注意)
     23 )
     25 # 4) fit & 評価
---> 26 ebm.fit(X_train, y_train)
     28 pred = ebm.predict(X_test)
     29 print(f"R^2: {r2_score(y_test, pred):.3f}")

File ~/work/notes/notes/.venv/lib/python3.10/site-packages/interpret/glassbox/_ebm/_ebm.py:1179, in EBMModel.fit(self, X, y, sample_weight, bags, init_score)
   1176     stop_flag = None
   1177     shm_name = None
-> 1179 results = parallel(
   1180     delayed(booster)(
   1181         shm_name=shm_name,
   1182         bag_idx=idx,
   1183         callback=callback,
   1184         dataset=(
   1185             shared.name if shared.name is not None else shared.dataset
   1186         ),
   1187         intercept_rounds=n_intercept_rounds,
   1188         intercept_learning_rate=develop.get_option(
   1189             "intercept_learning_rate"
   1190         ),
   1191         intercept=bagged_intercept[idx],
   1192         bag=internal_bags[idx],
   1193         # TODO: instead of making these copies we should
   1194         # put init_score into the native shared dataframe
   1195         init_scores=(
   1196             init_score
   1197             if (
   1198                 init_score is None
   1199                 or internal_bags[idx] is None
   1200                 or np.count_nonzero(internal_bags[idx])
   1201                 == len(internal_bags[idx])
   1202             )
   1203             else init_score[internal_bags[idx] != 0]
   1204         ),
   1205         term_features=term_features,
   1206         smoothing_rounds=smoothing_rounds,
   1207         # if there are no validation samples, turn off early stopping
   1208         # because the validation metric cannot improve each round
   1209         early_stopping_rounds=(
   1210             early_stopping_rounds
   1211             if (
   1212                 internal_bags[idx] is not None
   1213                 and (internal_bags[idx] < 0).any()
   1214             )
   1215             else 0
   1216         ),
   1217         rng=rngs[idx],
   1218     )
   1219     for idx in range(self.outer_bags)
   1220 )
   1222 best_iteration = [[]]
   1223 models = []

File ~/work/notes/notes/.venv/lib/python3.10/site-packages/joblib/parallel.py:2072, in Parallel.__call__(self, iterable)
   2066 # The first item from the output is blank, but it makes the interpreter
   2067 # progress until it enters the Try/Except block of the generator and
   2068 # reaches the first `yield` statement. This starts the asynchronous
   2069 # dispatch of the tasks to the workers.
   2070 next(output)
-> 2072 return output if self.return_generator else list(output)

File ~/work/notes/notes/.venv/lib/python3.10/site-packages/joblib/parallel.py:1682, in Parallel._get_outputs(self, iterator, pre_dispatch)
   1679     yield
   1681     with self._backend.retrieval_context():
-> 1682         yield from self._retrieve()
   1684 except GeneratorExit:
   1685     # The generator has been garbage collected before being fully
   1686     # consumed. This aborts the remaining tasks if possible and warn
   1687     # the user if necessary.
   1688     self._exception = True

File ~/work/notes/notes/.venv/lib/python3.10/site-packages/joblib/parallel.py:1800, in Parallel._retrieve(self)
   1789 if self.return_ordered:
   1790     # Case ordered: wait for completion (or error) of the next job
   1791     # that have been dispatched and not retrieved yet. If no job
   (...)
   1795     # control only have to be done on the amount of time the next
   1796     # dispatched job is pending.
   1797     if (nb_jobs == 0) or (
   1798         self._jobs[0].get_status(timeout=self.timeout) == TASK_PENDING
   1799     ):
-> 1800         time.sleep(0.01)
   1801         continue
   1803 elif nb_jobs == 0:
   1804     # Case unordered: jobs are added to the list of jobs to
   1805     # retrieve `self._jobs` only once completed or in error, which
   (...)
   1811     # timeouts before any other dispatched job has completed and
   1812     # been added to `self._jobs` to be retrieved.

KeyboardInterrupt: 

統計的推論#

[2601.18857v1] Statistical Inference for Explainable Boosting Machines

参考文献#