Explainable Boosting Machine

Explainable Boosting Machine (EBM) is a tree-based generalized additive model (GAM) whose main features are:

  • automatic detection of interactions

  • high interpretability

EBM (Nori et al., 2019) is an implementation of GA\(^2\)M (Lou et al., 2013) that was made parallelizable for speed.

Model#

A standard GAM

\[ g(E[y])= \beta_0 + \sum f_i(x_i) \]

extended with pairwise interaction terms,

\[ g(E[y])=\beta_0+\sum f_i(x_i)+\sum f_{i, j}(x_i, x_j) \]

is called a Generalized Additive Model plus Interactions (GA\(^2\)M).

Here \(g(\cdot)\) is the link function that adapts the GAM to regression or classification settings, \(x_i\) is the \(i\)-th feature, and \(f_i(\cdot)\) is the per-feature function (feature function).
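A minimal sketch of how a GA\(^2\)M prediction is assembled, assuming an identity link (regression) and two features. The concrete functions `f_1`, `f_2`, `f_12` and the value of `beta_0` are made up for illustration; in a fitted model they would be learned from data.

```python
import numpy as np

# Hypothetical components of a two-feature GA2M with identity link,
# so g(E[y]) = E[y]. None of these values come from a fitted model.
beta_0 = 1.0

def f_1(x1):
    return 0.5 * x1          # main effect of feature 1

def f_2(x2):
    return np.sin(x2)        # main effect of feature 2

def f_12(x1, x2):
    return 0.1 * x1 * x2     # pairwise interaction term

def predict(X):
    """Sum the intercept, the main effects, and the interaction term."""
    x1, x2 = X[:, 0], X[:, 1]
    return beta_0 + f_1(x1) + f_2(x2) + f_12(x1, x2)

X = np.array([[2.0, 0.0], [4.0, np.pi / 2]])
y_hat = predict(X)
print(y_hat)
```

Because the prediction is a plain sum, each term can be inspected or plotted on its own, which is where the interpretability comes from.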

GA\(^2\)M and EBM improve on traditional GAMs in two ways:

  1. The feature functions \(f_i\) are fitted with bagging and gradient boosting.

  2. They include an algorithm called FAST that automatically detects interaction terms \(f_{i,j}(x_i,x_j)\).
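The boosting of feature functions can be sketched as a round-robin loop that updates one \(f_i\) at a time on the current residuals. This is a simplified illustration, not InterpretML's implementation: each update here is the mean residual per bin (standing in for a shallow tree), bagging is omitted, and the data, bin count, and learning rate are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(0)
N, n_bins, lr, n_rounds = 2000, 16, 0.1, 200

# Synthetic additive data: y = x0^2 + sin(3 * x1) + noise
X = rng.uniform(-1, 1, size=(N, 2))
y = X[:, 0] ** 2 + np.sin(3 * X[:, 1]) + rng.normal(0, 0.05, N)

# Discretize each feature into quantile bins; each f_i becomes a lookup table.
cuts = np.linspace(0, 1, n_bins + 1)[1:-1]
bins = np.column_stack(
    [np.searchsorted(np.quantile(X[:, i], cuts), X[:, i]) for i in range(2)]
)

beta_0 = y.mean()
f = np.zeros((2, n_bins))          # f[i, k] = value of f_i in bin k
pred = np.full(N, beta_0)

for _ in range(n_rounds):
    for i in range(2):             # cycle through the features one at a time
        residual = y - pred
        # Piecewise-constant update: learning rate times mean residual per bin.
        sums = np.bincount(bins[:, i], weights=residual, minlength=n_bins)
        counts = np.bincount(bins[:, i], minlength=n_bins)
        update = lr * sums / np.maximum(counts, 1)
        f[i] += update
        pred += update[bins[:, i]]

mse = np.mean((y - pred) ** 2)
print(f"training MSE: {mse:.4f}")
```

Updating a single feature per step with a small learning rate keeps each \(f_i\) a function of its own feature only, so the fitted model stays additive.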

FAST#

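With \(n\) features there are \(O(n^2)\) candidate pairs of interaction terms, so the computational cost becomes high for high-dimensional data.

FAST evaluates each candidate pair in \(O(N+b^2)\) time, where \(N\) is the sample size and \(b\) is the number of bins per feature (for example, \(b=256\)).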
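The \(O(N+b^2)\) cost can be illustrated with a rough sketch in the spirit of FAST: one pass over the data fills a \(b \times b\) table of residual sums and counts, two-dimensional cumulative sums are built from it, and every pair of cut points is then scored in constant time by the reduction in squared error of a four-quadrant piecewise-constant fit. The function name `pair_strength`, the synthetic data, and the use of `y - y.mean()` as the residual (the main effects are zero by construction here) are assumptions for illustration, not the library's code.

```python
from itertools import combinations
import numpy as np

def pair_strength(bin_i, bin_j, residual, b):
    """Best squared-error reduction of a 2x2 piecewise-constant fit for one pair."""
    # O(N): accumulate residual sums and sample counts into b x b tables.
    flat = bin_i * b + bin_j
    s = np.bincount(flat, weights=residual, minlength=b * b).reshape(b, b)
    c = np.bincount(flat, minlength=b * b).reshape(b, b).astype(float)
    # O(b^2): cumulative tables; entry [u, v] covers bins <= u and <= v.
    S, C = s.cumsum(0).cumsum(1), c.cumsum(0).cumsum(1)
    quadrants = []
    for tab in (S, C):
        ll = tab[:-1, :-1]                  # low i, low j
        lh = tab[:-1, -1:] - ll             # low i, high j
        hl = tab[-1:, :-1] - ll             # high i, low j
        hh = tab[-1, -1] - ll - lh - hl     # high i, high j
        quadrants.append((ll, lh, hl, hh))
    # Error reduction of fitting one mean per quadrant: sum of S_q^2 / C_q.
    gain = sum(qs**2 / np.maximum(qc, 1.0) for qs, qc in zip(*quadrants))
    return gain.max()                       # best over all (b-1)^2 cut pairs

rng = np.random.default_rng(0)
N, b = 2000, 16
X = rng.uniform(-1, 1, size=(N, 3))
y = X[:, 0] * X[:, 1] + rng.normal(0, 0.1, N)   # only features 0 and 1 interact
residual = y - y.mean()

cuts = np.linspace(0, 1, b + 1)[1:-1]
bins = np.column_stack(
    [np.searchsorted(np.quantile(X[:, k], cuts), X[:, k]) for k in range(3)]
)
scores = {
    (i, j): pair_strength(bins[:, i], bins[:, j], residual, b)
    for i, j in combinations(range(3), 2)
}
best_pair = max(scores, key=scores.get)
print(best_pair, scores)
```

Ranking all pairs still requires visiting \(O(n^2)\) candidates, but each visit is cheap, which is what makes the search practical.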

Feature Importance#

Because the model is additive, the feature functions \(f_i\) show directly how much each feature contributed to a prediction.
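As a small worked example, a global importance score can be derived from the per-sample contributions \(f_i(x_i)\) by averaging their absolute values. The contribution values and the intercept below are made up for illustration; the feature names are columns of the California Housing dataset.

```python
import numpy as np

# Hypothetical per-sample contributions f_i(x_i): 4 samples x 3 features.
feature_names = ["MedInc", "HouseAge", "AveRooms"]
contributions = np.array([
    [ 0.8, -0.1,  0.2],
    [-0.6,  0.1,  0.0],
    [ 1.0, -0.2, -0.2],
    [-0.4,  0.0,  0.2],
])
intercept = 2.0  # hypothetical beta_0

# With an identity link, a prediction is the intercept plus the row sum.
predictions = intercept + contributions.sum(axis=1)

# Global importance: mean absolute contribution of each feature.
importance = np.abs(contributions).mean(axis=0)
ranking = [feature_names[k] for k in np.argsort(-importance)]

print(predictions)
print(dict(zip(feature_names, importance)))
print(ranking)
```

Each row of `contributions` is also a local explanation: it breaks one prediction down into the part attributable to each feature.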

Implementation#

EBM is implemented in the InterpretML package.

Explainable Boosting Machine — InterpretML documentation

import pandas as pd
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_absolute_error
from interpret.glassbox import ExplainableBoostingRegressor
from interpret import show

# 1) Load the data
data = fetch_california_housing(as_frame=True)
X: pd.DataFrame = data.data
y: pd.Series = data.target  # Median house value (in units of $100,000)

# 2) train/test split
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# 3) EBM (regression)
ebm = ExplainableBoostingRegressor(
    random_state=42,
    # interactions=10,   # number of pairwise interaction terms (more terms make interpretation harder)
    # max_bins=256,      # number of bins for continuous features (more bins is more flexible but risks overfitting)
)

# 4) Fit and evaluate
ebm.fit(X_train, y_train)

pred = ebm.predict(X_test)
print(f"R^2: {r2_score(y_test, pred):.3f}")
print(f"MAE: {mean_absolute_error(y_test, pred):.3f}")

# 5) Explain (global / local)
global_exp = ebm.explain_global(name="EBM global (California Housing)")
local_exp = ebm.explain_local(X_test.iloc[:5], y_test.iloc[:5], name="EBM local (first 5)")

# Visualize in a notebook environment (browser / embedded display)
show(global_exp)
show(local_exp)

Statistical Inference#

[2601.18857v1] Statistical Inference for Explainable Boosting Machines

References#