Explainable Boosting Machine / GA2M#

Explainable Boosting Machine (EBM) はtree-basedの一般化加法モデル(generalized additive model: GAM)を改良したもので、特徴は

  • GAM由来の解釈性の高さ

  • GAMより高精度:交互作用項を含む & shape function/base learnerがGBDT

  • 交互作用項の推定を高速化する工夫がとられている

なおEBM(Nori et al., 2019)はGA\(^2\)M(Lou et al., 2013)の実装を並列可能にして高速化したもの。

モデル#

Generalized Additive Model plus Interactions (GA\(^2\)M)

\[ g(E[y])=\beta_0+\sum f_i(x_i)+\sum f_{i, j}(x_i, x_j) \]
  • \(g(\cdot)\):link function。回帰や分類の設定にモデルを適用するための関数

  • \(x_i\)\(i\)番目の特徴量

  • \(f_i(\cdot)\):特徴量ごとの関数(feature function / shape function)

一般化線形モデル(Generalized Linear Model: GLM)

\[ g(E[y]) = \beta_0 + \sum \beta_i x_i \]

線形回帰モデルなどを一般化した一般化線形モデルは各特徴量からの寄与度\(\beta_i\) が明示的で、解釈性が高い。しかし線形モデルゆえにモデルの柔軟性が低く、非線形な構造のデータへの説明性に難がある。

一般化加法モデル(generalized additive model: GAM, Hastie & Tibshirani 1986

そこで線形回帰モデルを一般化したGAMが提案された。

\[ g(E[y])= \beta_0 + \sum f_i(x_i) \]

しかしGAMには交互作用項がないため、2変数間の交互作用を捉えられない(例えば不動産価格推定で、「緯度x経度」という交互作用は立地を表すのに有効だが、GAMだと考慮できない)。

Generalized Additive Model plus Interactions (GA\(^2\)M, Lou et al., 2013)

そこで、GAMに交互作用項を追加した

\[ g(E[y])=\beta_0+\sum f_i(x_i)+\sum f_{i, j}(x_i, x_j) \]

というモデルを Generalized Additive Model plus Interactions (GA\(^2\)M) という。

GA\(^2\)MやEBMは伝統的なGAMからの進化点として

  1. 特徴関数(feature function)\(f_i\)にbaggingや勾配ブースティングを使用する

  2. 予測に有用な交互作用項 \(f_{i,j}(x_i,x_j)\) の検知をする、「FAST」というアルゴリズムを備える

という特徴がある。

\(GA^2M\)のアルゴリズム#

GAMの主要なアルゴリズムには

  1. backfitting

  2. gradient boosting

がある。\(GA^2M\)ではgradient boostingを採用し、勾配(残差など)を近似するように学習させる

記号の意味(notation)

以下で使う記号をまとめておく

  • \(n\):特徴量の次元数

  • \(u \subseteq \{ 1, \dots, n \}\):特徴量のインデックスの集合

    • \(\boldsymbol{x} = (x_1, \dots, x_n)\):特徴量ベクトル

    • \(\boldsymbol{x}_u\)\(u\)に含まれるインデックスの特徴量だけの特徴量ベクトル

  • \(\mathcal{U}\):すべての特徴量とそのペアのインデックス集合

    • \(\mathcal{U}^1=\{\{i\} \mid 1 \leq i \leq n\}\)

    • \(\mathcal{U}^2=\{\{i, j\} \mid 1 \leq i<j \leq n\}\)

    • \(\mathcal{U}=\mathcal{U}^1 \cup \mathcal{U}^2\)

  • \(\mathcal{H}_u\):ルベーグ可測な関数\(f_u(x_u)\)のヒルベルト空間

    • \(\mathcal{H}^1=\sum_{u \in \mathcal{U}^1} \mathcal{H}_u\): 単変量成分に対して加法形 \(F(\boldsymbol{x})=\sum_{u \in \mathcal{U}^1} f_u\left(x_u\right)\) をもつ関数のヒルベルト空間

    • なお、これらの成分 \(f_u\)shape functions と呼ぶ。

Two-stage Construction#

GA\(^2\)M全体の構築は2段階に行う。

Two-stage Construction

  1. Stage1:1次元の成分\(f(x_i)\)のみで最良の加法モデル\(F\in \mathcal{H}^1\)を構築する

  2. Stage2:1次元の関数たち\(\sum f(x_i)\)を固定したもとで、残差に対して交互作用項モデル\(\sum f(x_i, x_j)\)を構築

交互作用項モデル\(\sum f(x_i, x_j)\)はすべての組み合わせを入れるのではなく、予測性能の高い交互作用項のみで行う。予測性能の評価は計算量が高くなりがちで難しい問題なので、FAST(提案手法)で高速な近似計算を行う。

交互作用検出アルゴリズム#

2次の交互作用項は、その組み合わせ数の多さから計算量が非常に大きくなる。

また、有効な交互作用項の検出・選別方法については先行研究がいくつかあるが、一度モデルをfittingさせる必要があるものが多く、計算量や精度の課題がある。

先行研究#

手法

概要

問題点

ANOVA

全ペアワイズ交互作用項を含む加法モデルをフィット後、分散分析で有意性を検定

フルモデルの計算コストが膨大

偏依存関数(PDF)

Friedman & Popescu提案の統計量 \(H^2_{ij}\) で交互作用強度を測定

低密度領域で偽の交互作用を検出する可能性

GUIDE

カイ二乗検定による交互作用検出

FASTより検出力が低い

Grove

制約モデルと非制約モデルのRMSE差で交互作用強度を定量化

計算コストが極めて高い(数日単位)

合成データによる実験では Grove と ANOVA が最も正確で、提案手法 のFAST はそれにほぼ匹敵する検出精度だった

FAST(提案手法)#

概要 / naive fast#

交互作用項 \(f_{ij}(x_i,x_j)\) を(GBDT等で)完全に推定するのは計算コストが高い。そのため、 きわめてシンプルなtreeモデル\(T_{ij}\)で交互作用項の効果を概算する。

  1. まず、2つの変数\(x_i,x_j\)のそれぞれの軸に平行なcutoff点 \(c_i, c_j\)をつくり、\(x_i, x_j\)の2次元の特徴空間を4象限に区切る。

  2. 各象限内のデータ点の平均をとって予測値とする

  3. すべての可能な\((c_i, c_j)\)の点について残差平方和 \(RSS\)が最小化される最適な\(T_{ij}\)を選択する

Construction Predictors#

このtreeの学習・評価をそのまま行うと、すべての切断点\((c_i, c_j)\)について木\(T_{ij}\)を構築する必要があり、計算量が高い。そこで、

  1. 特徴量をヒストグラムにして探索を効率化

  2. ヒストグラムのもとで各点\((c_i, c_j)\)に対する情報を事前に計算しておき、各木\(T_{ij}\)の構築を高速化

する。

変数\(x_i\)がとりうる値のソート済み集合を \(dom(x_i) = {v_1,...,v_{d_i}}\) とする。ここで\(d_i = |dom(x_i)|\)

\(x_i=v\)のときの目的変数の総和を\(H_i^t(v)\)を、重み(あるいはサンプル数)の総和を\(H_i^w(v)\)とおく。直感的には、これらは回帰木におけるヒストグラムである。同様に \(H_{ij}^t(u,v)\)および\(H_{ij}^w(u,v)\) を、\((x_i,x_j)=(u,v)\) のときの目的変数の総和および重みの総和として定義する。

また、累積ヒストグラムを \(C H_i^t(v)\)\(CH_i^w(v)\)とおく。

  • \(C H_i^t(v)=\sum_{u \leq v} H_i^t(u)\)

  • \(C H_i^w(v)=\sum_{u \leq v} H_i^w(u)\)

さらに、次のように定義する。

  • \(\overline{C H_i^t}(v)=\sum_{u>v} H_i^t(u) = CH_i^t(v_i^{d_i})-CH_i^t(v)\)

  • \(\overline{C H_i^w}(v)=\sum_{u>v} H_i^w(u) = CH_i^w(v_i^{d_i})-CH_i^w(v)\)

cutoff \((c_i, c_j)\) で区切った4象限 \(a,b,c,d\)を計算する。

赤い領域\(a\)がわかれば、他の象限も周辺累積ヒストグラム\(CH\)を使って簡単に計算できる。\(a\)は動的計画法を用いて高速に計算できる。

切断点 \((c_i, c_j)\) における目的変数の総和のlookup tableを \(L^t(c_i, c_j)=[a, b, c, d]\) とする。同様に重みの総和のlookup tableを \(L^w\left(c_i, c_j\right)=[a, b, c, d]\) とする。

Lookup Tableの作り方を詳しく述べたのがAlgorithm 2である。
切断点\((p, q)\)における値を\(a[p][q]\)とする。この1行目\(a[1][q]\)を先に計算してから他の行を計算する。

Algorithm 2 ConstructLookupTable

  1. \(\operatorname{sum} \leftarrow 0\)(初期化)

  2. for \(q=1\) to \(d_j\) do(Lookup Tableの1行目を計算) 3. \(\operatorname{sum} \leftarrow \operatorname{sum}+H_{i j}^t\left(v_i^1, v_j^q\right)\) 4. \(a[1][q] \leftarrow\) sum 5. \(L\left(v_i^1, v_j^q\right) \leftarrow \operatorname{ComputeValues} \left(C H_i^t, C H_j^t, a[1][q]\right)\)

  3. for \(p=2\) to \(d_i\) do 7. \(\operatorname{sum} \leftarrow 0\) 8. for \(q=1\) to \(d_j\) do 9. \(\operatorname{sum} \leftarrow \operatorname{sum} +H_{i j}^t\left(v_i^p, v_j^q\right)\) 10. \(a[p][q] \leftarrow \operatorname{sum}+a[p-1][q]\) 11. \(L\left(v_i^p, v_j^q\right) \leftarrow \operatorname{ComputeValues} \left(C H_i^t, C H_j^t, a[p][q]\right)\)

一度\(L^t, L^w\)が求まれば、任意の切断点\((c_i, c_j)\)について木\(T_{ij}\)が計算できる。例えば木\(T_{ij}\)で最も左の葉ノード(4象限\(\{a,b,c,d\}\)のうち\(a\))の値は\( L^t(c_i, c_j).a \ / L^w(c_i, c_j).a\) で求められる。LookupTableからの値の取得と1回の計算程度なので計算量は \(O(1)\) となる

Calculating RSS#

\(T_{ij}\) に対する \(RSS\) の計算は非常に効率的に行える。

まず、\(RSS\) の定義を考える。\(T_{ij}.r\) を領域 \(r\) における予測値とする(ただし \(r \in \{a, b, c, d\}\))。

\[\begin{split} \begin{aligned} RSS &= \sum_{k=1}^{N} \left(y_k - T_{ij}(x_k)\right)^2\\ &= \left( \sum_{k=1}^{N} y_k^2 \;-\; 2 \sum_{r} T_{ij}.r \, L^{t}.r \;+\; \sum_{r} (T_{ij}.r)^2 \, L^{w}.r \right) \end{aligned} \end{split}\]

\(RSS\) の絶対値ではなく相対的な大小関係にのみ関心があるため、実際の実装においては、以下の部分だけを考えればよい:

\[ \sum_{r} (T_{ij}.r)^2 L^{w}.r \;-\; 2 \sum_{r} T_{ij}.r \, L^{t}.r \]

このことから、\(T_{ij}\) に対する \(RSS\) の計算量は \(O(1)\) であることが容易に分かる。

FASTの計算量#
  • 時間計算量は1つのペア \((x_i, x_j)\) あたり \(O(d_i d_j + N)\)

    • 各ペア \((x_i, x_j)\) に対して、ヒストグラムおよび累積ヒストグラムを計算するにはデータ全体を走査する必要があるため、その計算量は \(O(N)\)

    • またLookup Table の構築には \(O(d_i d_j + N)\) の時間計算量が必要なため

  • 空間計算量は \(O(d_i d_j)\)

    • 各ペアごとに \(d_i \times d_j\) の行列を保持する必要があるため

特徴量の離散化(binning)#

連続値特徴量の場合、\(d_i d_j\) は非常に大きくなる可能性がある。しかし、特徴量を \(b\) 個の等頻度ビン(equi-frequency bins)に離散化することで、計算量はペアあたり \(O(b^2 + N)\) に削減できる。GA2Mでは \(b = 256\) としている。

このような離散化は通常、回帰木の性能を損なわないことが先行研究で知られている。また実験により FAST は幅広い \(b\) の値に対して感度が低いことがわかっている。

各項の重み / Feature Importance#

解釈性を高めるために、仮説空間 \(\mathcal{H}\) における最適なモデルを学習した後、すべての項(1次元および2次元の成分)の重要度をランキングし、各項に重みをつける。

本研究では、項 \(u\) に対して、\(f_u\)(ただし \(E[f_u] = 0\))の標準偏差である

\[ \sqrt{E[f_u^2]} \]

を重みとして用いる。

これは線形モデルにおける重みの自然な一般化である。というのも、\(f_i(x_i) = w_i x_i\) の場合、特徴量が正規化されて \(E[x_i^2] = 1\) を満たすとき、\(\sqrt{E[f_i^2]}\)\(|w_i|\) に一致するためである。

予測性能#

提案論文での実験では回帰問題・分類問題それぞれ複数のデータセットに対し、 Random Forestと同程度の予測性能 を示した。

実装#

InterpretMLパッケージに実装されている

Explainable Boosting Machine — InterpretML documentation

import pandas as pd
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_absolute_error
from interpret.glassbox import ExplainableBoostingRegressor
from interpret import show

# 1) データロード
data = fetch_california_housing(as_frame=True)
X: pd.DataFrame = data.data
y: pd.Series = data.target  # Median house value

# 2) train/test split
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# 3) EBM(回帰)
ebm = ExplainableBoostingRegressor(
    random_state=42,
    # interactions=10,   # 交互作用を入れたいなら(解釈はやや複雑に)
    # max_bins=256,      # 連続特徴のビン数(大きいほど柔軟だが過学習注意)
)

# 4) fit & 評価
ebm.fit(X_train, y_train)

pred = ebm.predict(X_test)
print(f"R^2: {r2_score(y_test, pred):.3f}")
print(f"MAE: {mean_absolute_error(y_test, pred):.3f}")

# 5) Explain(global / local)
global_exp = ebm.explain_global(name="EBM global (California Housing)")
local_exp = ebm.explain_local(X_test.iloc[:5], y_test.iloc[:5], name="EBM local (first 5)")

# Notebook環境なら可視化(ブラウザ/埋め込み表示)
show(global_exp)
show(local_exp)
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Cell In[1], line 26
     19 ebm = ExplainableBoostingRegressor(
     20     random_state=42,
     21     # interactions=10,   # 交互作用を入れたいなら(解釈はやや複雑に)
     22     # max_bins=256,      # 連続特徴のビン数(大きいほど柔軟だが過学習注意)
     23 )
     25 # 4) fit & 評価
---> 26 ebm.fit(X_train, y_train)
     28 pred = ebm.predict(X_test)
     29 print(f"R^2: {r2_score(y_test, pred):.3f}")

File ~/work/notes/notes/.venv/lib/python3.10/site-packages/interpret/glassbox/_ebm/_ebm.py:1179, in EBMModel.fit(self, X, y, sample_weight, bags, init_score)
   1176     stop_flag = None
   1177     shm_name = None
-> 1179 results = parallel(
   1180     delayed(booster)(
   1181         shm_name=shm_name,
   1182         bag_idx=idx,
   1183         callback=callback,
   1184         dataset=(
   1185             shared.name if shared.name is not None else shared.dataset
   1186         ),
   1187         intercept_rounds=n_intercept_rounds,
   1188         intercept_learning_rate=develop.get_option(
   1189             "intercept_learning_rate"
   1190         ),
   1191         intercept=bagged_intercept[idx],
   1192         bag=internal_bags[idx],
   1193         # TODO: instead of making these copies we should
   1194         # put init_score into the native shared dataframe
   1195         init_scores=(
   1196             init_score
   1197             if (
   1198                 init_score is None
   1199                 or internal_bags[idx] is None
   1200                 or np.count_nonzero(internal_bags[idx])
   1201                 == len(internal_bags[idx])
   1202             )
   1203             else init_score[internal_bags[idx] != 0]
   1204         ),
   1205         term_features=term_features,
   1206         smoothing_rounds=smoothing_rounds,
   1207         # if there are no validation samples, turn off early stopping
   1208         # because the validation metric cannot improve each round
   1209         early_stopping_rounds=(
   1210             early_stopping_rounds
   1211             if (
   1212                 internal_bags[idx] is not None
   1213                 and (internal_bags[idx] < 0).any()
   1214             )
   1215             else 0
   1216         ),
   1217         rng=rngs[idx],
   1218     )
   1219     for idx in range(self.outer_bags)
   1220 )
   1222 best_iteration = [[]]
   1223 models = []

File ~/work/notes/notes/.venv/lib/python3.10/site-packages/joblib/parallel.py:2072, in Parallel.__call__(self, iterable)
   2066 # The first item from the output is blank, but it makes the interpreter
   2067 # progress until it enters the Try/Except block of the generator and
   2068 # reaches the first `yield` statement. This starts the asynchronous
   2069 # dispatch of the tasks to the workers.
   2070 next(output)
-> 2072 return output if self.return_generator else list(output)

File ~/work/notes/notes/.venv/lib/python3.10/site-packages/joblib/parallel.py:1682, in Parallel._get_outputs(self, iterator, pre_dispatch)
   1679     yield
   1681     with self._backend.retrieval_context():
-> 1682         yield from self._retrieve()
   1684 except GeneratorExit:
   1685     # The generator has been garbage collected before being fully
   1686     # consumed. This aborts the remaining tasks if possible and warn
   1687     # the user if necessary.
   1688     self._exception = True

File ~/work/notes/notes/.venv/lib/python3.10/site-packages/joblib/parallel.py:1800, in Parallel._retrieve(self)
   1789 if self.return_ordered:
   1790     # Case ordered: wait for completion (or error) of the next job
   1791     # that have been dispatched and not retrieved yet. If no job
   (...)
   1795     # control only have to be done on the amount of time the next
   1796     # dispatched job is pending.
   1797     if (nb_jobs == 0) or (
   1798         self._jobs[0].get_status(timeout=self.timeout) == TASK_PENDING
   1799     ):
-> 1800         time.sleep(0.01)
   1801         continue
   1803 elif nb_jobs == 0:
   1804     # Case unordered: jobs are added to the list of jobs to
   1805     # retrieve `self._jobs` only once completed or in error, which
   (...)
   1811     # timeouts before any other dispatched job has completed and
   1812     # been added to `self._jobs` to be retrieved.

KeyboardInterrupt: 

信頼区間 / 予測区間の推定#

[2601.18857v1] Statistical Inference for Explainable Boosting Machines はEBMの信頼区間や予測区間を行う方法を提案。

参考文献#