Explainable Boosting Machine
Explainable Boosting Machine (EBM) is a tree-based generalized additive model (GAM) whose main features are:

- automatic detection of interactions
- high interpretability

EBM (Nori et al., 2019) is a parallelizable, faster implementation of GA\(^2\)M (Lou et al., 2013).
Model
Taking a standard GAM

\[ g(E[y]) = \beta_0 + \sum_i f_i(x_i) \]

and adding pairwise interaction terms,

\[ g(E[y]) = \beta_0 + \sum_i f_i(x_i) + \sum_{i, j} f_{i,j}(x_i, x_j), \]

yields the model called Generalized Additive Model plus Interactions (GA\(^2\)M).

Here, \(g(\cdot)\) is the link function that adapts the GAM to the regression or classification setting (e.g. the identity for regression, the logit for binary classification), \(x_i\) is the \(i\)-th feature, and \(f_i(\cdot)\) is the feature function for each feature.
Compared with traditional GAMs, GA\(^2\)M and EBM add two improvements:

- the feature functions \(f_i\) are fitted with bagging and gradient boosting (see the sketch below)
- interaction terms \(f_{i,j}(x_i, x_j)\) are detected automatically by an algorithm called FAST
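The boosting side can be pictured as a round-robin loop that updates one feature function at a time with a small step fitted to the current residuals. Below is a minimal numpy sketch of that idea (my own illustration, not the InterpretML implementation, which uses shallow trees, bagging, and many additional details):

```python
import numpy as np

def fit_cyclic_gam(X, y, n_bins=32, learning_rate=0.01, n_rounds=300):
    """Round-robin, one-feature-at-a-time boosting of per-bin feature functions.

    Minimal sketch of the idea behind EBM's feature functions; the real
    library uses shallow trees, bagging, interactions and more.
    """
    n, p = X.shape
    # Bin each feature once; each feature function f_j is a per-bin lookup table.
    edges = [np.quantile(X[:, j], np.linspace(0, 1, n_bins + 1)[1:-1]) for j in range(p)]
    bins = np.column_stack([np.searchsorted(edges[j], X[:, j]) for j in range(p)])
    intercept = float(y.mean())
    scores = [np.zeros(n_bins) for _ in range(p)]
    pred = np.full(n, intercept)
    for _ in range(n_rounds):
        for j in range(p):                       # cycle through the features
            residual = y - pred
            # Per-bin mean residual is the best single-feature update (squared loss);
            # take only a small step, as in gradient boosting.
            sums = np.bincount(bins[:, j], weights=residual, minlength=n_bins)
            counts = np.bincount(bins[:, j], minlength=n_bins)
            update = learning_rate * np.divide(
                sums, counts, out=np.zeros(n_bins), where=counts > 0
            )
            scores[j] += update
            pred = pred + update[bins[:, j]]
    return intercept, edges, scores              # intercept + feature functions
```

Because every \(f_j\) ends up as a per-bin lookup table, prediction is just the intercept plus one table lookup per feature, which is what makes the fitted model easy to inspect.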
FAST
Naively, with \(n\) features there are \(O(n^2)\) candidate interaction pairs, so searching for interactions becomes expensive on high-dimensional data.
FAST scores each candidate pair in \(O(N + b^2)\) time, where \(N\) is the sample size and \(b\) is the number of bins per feature (e.g. \(b = 256\)), which makes the pairwise search tractable.
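To make the per-pair cost concrete, here is a simplified numpy sketch of the FAST idea (my own illustration, not the InterpretML code): bin both features, accumulate the GAM residuals on a \(b \times b\) grid in \(O(N)\), then use 2D prefix sums to evaluate every axis-aligned cut in \(O(b^2)\) and keep the best squared-error reduction.

```python
import numpy as np

def fast_pair_score(x_i, x_j, residual, n_bins=16):
    """Score one feature pair on the residuals of an already-fitted GAM.

    Simplified sketch of the FAST heuristic (Lou et al., 2013): for every
    axis-aligned cut, fit one constant per quadrant and keep the best
    squared-error reduction. Larger score = stronger candidate interaction.
    """
    # Bin both features (equi-frequency bins here; EBM uses its own binning).
    cuts_i = np.quantile(x_i, np.linspace(0, 1, n_bins + 1)[1:-1])
    cuts_j = np.quantile(x_j, np.linspace(0, 1, n_bins + 1)[1:-1])
    bi = np.searchsorted(cuts_i, x_i)            # bin index in 0..n_bins-1
    bj = np.searchsorted(cuts_j, x_j)

    # O(N): accumulate residual sums and counts on a b x b grid.
    sum_grid = np.zeros((n_bins, n_bins))
    cnt_grid = np.zeros((n_bins, n_bins))
    np.add.at(sum_grid, (bi, bj), residual)
    np.add.at(cnt_grid, (bi, bj), 1.0)

    # 2D prefix sums give O(1) quadrant totals for any cut.
    S = sum_grid.cumsum(axis=0).cumsum(axis=1)
    C = cnt_grid.cumsum(axis=0).cumsum(axis=1)
    total_s, total_c = S[-1, -1], C[-1, -1]

    best = 0.0
    # O(b^2): try every (cut_i, cut_j) and fit one constant per quadrant.
    for ci in range(n_bins - 1):
        for cj in range(n_bins - 1):
            s11, c11 = S[ci, cj], C[ci, cj]
            s12, c12 = S[ci, -1] - s11, C[ci, -1] - c11
            s21, c21 = S[-1, cj] - s11, C[-1, cj] - c11
            s22, c22 = total_s - s11 - s12 - s21, total_c - c11 - c12 - c21
            # Squared-error reduction of the piecewise-constant fit: sum of s^2 / c.
            gain = sum(s * s / c for s, c in
                       ((s11, c11), (s12, c12), (s21, c21), (s22, c22)) if c > 0)
            best = max(best, gain)
    return best
```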
Feature Importance
Because the model is an explicit sum of feature functions \(f_i\), it is clear how much each feature contributed to a prediction.
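A common global summary is to average the absolute contribution \(|f_i(x_i)|\) over the training data, which is roughly what the term importances shown in InterpretML's global explanations correspond to. A minimal sketch of that summary (the variable names are hypothetical, matching the lookup-table sketch above, not InterpretML's API):

```python
import numpy as np

# EBM-style term importance: the average absolute contribution of one
# feature function over the training data.
# `bin_idx` holds the per-row bin index for feature i, `f_i` the per-bin scores;
# both names are hypothetical, matching the sketch above.
def term_importance(bin_idx: np.ndarray, f_i: np.ndarray) -> float:
    contributions = f_i[bin_idx]                # f_i(x_i) for every training row
    return float(np.abs(contributions).mean())  # mean |contribution|
```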
Implementation
EBM is implemented in the InterpretML package:
Explainable Boosting Machine — InterpretML documentation
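If it is not already available, the package can be installed from PyPI as `interpret` (e.g. `pip install interpret`). The example below fits an EBM regressor on the California Housing data: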
import pandas as pd
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_absolute_error
from interpret.glassbox import ExplainableBoostingRegressor
from interpret import show
# 1) Load data
data = fetch_california_housing(as_frame=True)
X: pd.DataFrame = data.data
y: pd.Series = data.target # Median house value
# 2) train/test split
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42
)
# 3) EBM (regression)
ebm = ExplainableBoostingRegressor(
random_state=42,
# interactions=10, # enable interaction terms (interpretation becomes somewhat more complex)
# max_bins=256, # number of bins for continuous features (larger is more flexible, but beware of overfitting)
)
# 4) Fit & evaluate
ebm.fit(X_train, y_train)
pred = ebm.predict(X_test)
print(f"R^2: {r2_score(y_test, pred):.3f}")
print(f"MAE: {mean_absolute_error(y_test, pred):.3f}")
# 5) Explain (global / local)
global_exp = ebm.explain_global(name="EBM global (California Housing)")
local_exp = ebm.explain_local(X_test.iloc[:5], y_test.iloc[:5], name="EBM local (first 5)")
# In a notebook environment, visualize (browser / embedded display)
show(global_exp)
show(local_exp)
Statistical Inference
[2601.18857v1] Statistical Inference for Explainable Boosting Machines