Quantized Training of LightGBM#
Overview#
In the decision trees used by recent GBDT implementations such as LightGBM, the output of a leaf $s$ is computed as

$$w_s^{*} = -\frac{\sum_{i \in I_s} g_i}{\sum_{i \in I_s} h_i + \lambda},$$

where $g_i$ and $h_i$ are the gradient and hessian of the loss for sample $i$ and $I_s$ is the set of samples falling into the leaf.
Quantized training (Shi et al., 2022) rounds these gradients and hessians to low-bitwidth integers, so that most of the histogram accumulation during training can be done with cheap integer arithmetic.
This note summarizes the method and then compares LightGBM training with and without quantization (the `use_quantized_grad` option) on a small regression task.
Training procedure of modern GBDT#
(Based on Shi et al. (2022), Quantized Training of Gradient Boosting Decision Trees.)
First, let us review the GBDT training procedure and fix the notation.
Gradient boosting decision trees (GBDT) are an ensemble learning approach that combines multiple decision trees.
At each iteration, the gradients and hessians with respect to the current predictions are computed, and a decision tree is trained to approximate the negative gradients. With loss function $l$ and current prediction $\hat{y}_i$ for sample $i$, the gradient and hessian are

$$g_i = \frac{\partial l(y_i, \hat{y}_i)}{\partial \hat{y}_i}, \qquad h_i = \frac{\partial^2 l(y_i, \hat{y}_i)}{\partial \hat{y}_i^2}.$$
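As a quick concrete example (mine, not spelled out in the paper), the squared-error loss gives particularly simple derivatives,

$$l(y_i, \hat{y}_i) = \tfrac{1}{2}(y_i - \hat{y}_i)^2 \;\Longrightarrow\; g_i = \hat{y}_i - y_i, \qquad h_i = 1,$$

so the hessian is constant; this will matter again when hessian quantization is discussed below.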
If we write $I_s$ for the set of indices of the training samples that fall into leaf $s$ and $w_s$ for the output value of that leaf, then the second-order approximation of the objective at iteration $t$ can be expressed as

$$\mathcal{L}^{(t)} \approx \sum_{s} \left[ \left( \sum_{i \in I_s} g_i \right) w_s + \frac{1}{2} \left( \sum_{i \in I_s} h_i + \lambda \right) w_s^2 \right],$$

where $\lambda$ is the regularization coefficient for the leaf values.
Since finding the optimal tree structure is computationally hard, the tree is grown greedily and iteratively.
For a fixed structure, the optimal output value of leaf $s$ is

$$w_s^{*} = -\frac{\sum_{i \in I_s} g_i}{\sum_{i \in I_s} h_i + \lambda}.$$

When leaf $s$ is split into children $s_1$ and $s_2$, the resulting reduction in loss (the split gain) is

$$\frac{\left(\sum_{i \in I_{s_1}} g_i\right)^2}{\sum_{i \in I_{s_1}} h_i + \lambda} + \frac{\left(\sum_{i \in I_{s_2}} g_i\right)^2}{\sum_{i \in I_{s_2}} h_i + \lambda} - \frac{\left(\sum_{i \in I_{s}} g_i\right)^2}{\sum_{i \in I_{s}} h_i + \lambda},$$

and at each step the split with the largest gain is chosen.
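As a toy illustration of these two formulas (my own sketch, not LightGBM code), the leaf value and the gain of a candidate split can be computed directly from the per-leaf gradient and hessian sums:

import numpy as np

def leaf_value(g, h, lam=1.0):
    # w_s* = -sum(g) / (sum(h) + lambda)
    return -np.sum(g) / (np.sum(h) + lam)

def leaf_score(g, h, lam=1.0):
    # (sum(g))^2 / (sum(h) + lambda), the per-leaf term of the split gain
    return np.sum(g) ** 2 / (np.sum(h) + lam)

def split_gain(g, h, left, lam=1.0):
    # Gain of splitting a leaf into the children defined by the boolean mask `left`.
    g, h, left = map(np.asarray, (g, h, left))
    return (leaf_score(g[left], h[left], lam)
            + leaf_score(g[~left], h[~left], lam)
            - leaf_score(g, h, lam))

rng = np.random.default_rng(0)
g, h = rng.normal(size=100), np.ones(100)   # e.g. squared-error loss
left = rng.random(100) < 0.5                # an arbitrary candidate split
print(leaf_value(g, h), split_gain(g, h, left))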
LightGBM uses histograms to speed up the search for the best split point. The basic idea of histogram-based GBDT is to bucket the feature values into bins; each histogram bin records the sums of the gradients and hessians of the data falling into it, and only the bin boundaries are considered as split candidates.
Algorithm 1: Histogram Construction for Leaf $s$
Input: gradients $\{g_i\}$, hessians $\{h_i\}$
Input: bin data $\mathrm{data}$, data indices in leaf $s$ denoted by $I_s$
Output: histogram $\mathrm{hist}_s$
for $i \in I_s$, $j \in \{1, \dots, J\}$ do
    $\mathrm{bin} \leftarrow \mathrm{data}[j][i]$
    $\mathrm{hist}_s[j][\mathrm{bin}].g \leftarrow \mathrm{hist}_s[j][\mathrm{bin}].g + g_i$
    $\mathrm{hist}_s[j][\mathrm{bin}].h \leftarrow \mathrm{hist}_s[j][\mathrm{bin}].h + h_i$
end for
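A didactic Python transcription of Algorithm 1 (my own sketch; LightGBM's real histogram kernels are heavily optimized C++) makes the accumulation explicit:

import numpy as np

def build_histogram(grad, hess, bin_data, leaf_indices, n_bins):
    # bin_data[j][i] is the bin index of sample i for feature j.
    n_features = len(bin_data)
    hist_g = np.zeros((n_features, n_bins))
    hist_h = np.zeros((n_features, n_bins))
    for i in leaf_indices:
        for j in range(n_features):
            b = bin_data[j][i]
            hist_g[j, b] += grad[i]  # one add per (sample, feature) pair
            hist_h[j, b] += hess[i]
    return hist_g, hist_h

# Example: 3 samples, 2 features, 4 bins per feature
grad, hess = np.array([0.5, -0.2, 0.1]), np.ones(3)
bin_data = np.array([[0, 1, 3], [2, 2, 0]])   # shape (n_features, n_samples)
hist_g, hist_h = build_histogram(grad, hess, bin_data, leaf_indices=[0, 1, 2], n_bins=4)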
Traditionally, the gradients and hessians are stored as 32-bit or 64-bit floating-point numbers, so histogram construction is dominated by floating-point additions; this is the part that quantized training replaces with integer arithmetic.
Framework for Quantized Training#
First, the gradients and hessians are quantized into $B$-bit integers by dividing their value ranges into intervals of equal length. The first-order derivatives $g_i$ of all training samples fall into the range $[-\max_i |g_i|, \max_i |g_i|]$, and the second-order derivatives $h_i$ fall into $[0, \max_i h_i]$ (hessians are non-negative for common loss functions). The interval length for the gradients is therefore

$$\delta_g = \frac{\max_i |g_i|}{2^{B-1} - 1},$$

and $\delta_h$ is defined analogously over $[0, \max_i h_i]$. With these, the low-bitwidth gradients and hessians can be computed as

$$\tilde{g}_i = \mathrm{Round}\!\left(\frac{g_i}{\delta_g}\right), \qquad \tilde{h}_i = \mathrm{Round}\!\left(\frac{h_i}{\delta_h}\right),$$

where $\mathrm{Round}(\cdot)$ maps its argument to a nearby integer.
Note that if the hessians are constant, as with the squared-error loss, there is no need to quantize them.
The detailed rounding strategy is described in Section 4.2 of the paper (summarized in the next section).
The $g_i$ and $h_i$ in Algorithm 1 are then replaced by the quantized $\tilde{g}_i$ and $\tilde{h}_i$, so that histogram accumulation turns into integer additions. As a small numerical example with $B = 2$:
B = 2  # bit width of the quantized gradients
g = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
delta_g = max([abs(gi) for gi in g]) / (2**(B-1) - 1)  # interval length delta_g
tilde_g = [round(gi / delta_g) for gi in g]  # round-to-nearest (Python rounds 0.5 to the even integer 0)
tilde_g
[0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
The reduction in loss from a split is likewise estimated from the quantized gradient and hessian statistics accumulated in the histograms.
The authors found that 2- to 4-bit quantized gradients already provide sufficiently good accuracy, and, as discussed in Sections 6.1 and 7.4.3 of the paper, 16-bit integers suffice for the accumulated low-bitwidth gradients in the histograms. Most of the arithmetic is therefore carried out with low-bitwidth integers; floating-point operations are needed only for computing the original gradients and hessians and the split gains. In particular, the split gain is estimated as

$$\frac{\left(\delta_g \sum_{i \in I_{s_1}} \tilde{g}_i\right)^2}{\delta_h \sum_{i \in I_{s_1}} \tilde{h}_i + \lambda} + \frac{\left(\delta_g \sum_{i \in I_{s_2}} \tilde{g}_i\right)^2}{\delta_h \sum_{i \in I_{s_2}} \tilde{h}_i + \lambda} - \frac{\left(\delta_g \sum_{i \in I_{s}} \tilde{g}_i\right)^2}{\delta_h \sum_{i \in I_{s}} \tilde{h}_i + \lambda},$$

where the integer gradient statistics are rescaled back to the original range by $\delta_g$ and $\delta_h$.
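To make the rescaling concrete, here is a toy sketch (my own, with made-up integer sums) of a split-gain evaluation from quantized histogram statistics:

# Hypothetical int16 sums of tilde_g / tilde_h for the two children of a leaf,
# together with the quantization scales delta_g, delta_h and regularization lambda.
delta_g, delta_h, lam = 0.05, 0.02, 1.0
G1, H1 = 37, 210    # child s1
G2, H2 = -12, 175   # child s2
Gp, Hp = G1 + G2, H1 + H2  # parent leaf s

def gain_term(G, H):
    # Rescale the integer sums back to the original scale before computing the gain term.
    return (delta_g * G) ** 2 / (delta_h * H + lam)

gain = gain_term(G1, H1) + gain_term(G2, H2) - gain_term(Gp, Hp)
print(gain)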
Figure 1 of the paper summarizes the workflow of quantized GBDT training.
Rounding Strategies and Leaf-Value Refitting#
Round-to-nearest rounding of $g_i / \delta_g$ and $h_i / \delta_h$ was found to degrade accuracy substantially. Instead, stochastic rounding is used:

$$\tilde{g}_i = \begin{cases} \left\lceil g_i / \delta_g \right\rceil & \text{with probability } \dfrac{g_i}{\delta_g} - \left\lfloor \dfrac{g_i}{\delta_g} \right\rfloor, \\[1.5ex] \left\lfloor g_i / \delta_g \right\rfloor & \text{otherwise,} \end{cases}$$

and likewise for $\tilde{h}_i$.
Since split gains are computed from sums of gradients and hessians, stochastic rounding makes the quantized sums unbiased estimators of the original sums, i.e.

$$\mathbb{E}\!\left[\delta_g \sum_{i \in I_s} \tilde{g}_i\right] = \sum_{i \in I_s} g_i, \qquad \mathbb{E}\!\left[\delta_h \sum_{i \in I_s} \tilde{h}_i\right] = \sum_{i \in I_s} h_i.$$
The importance of stochastic rounding has also been recognized in quantized training of neural networks [13] and in the histogram decomposition of DimBoost [16].
With quantized gradients, the optimal leaf value becomes

$$\tilde{w}_s^{*} = -\frac{\delta_g \sum_{i \in I_s} \tilde{g}_i}{\delta_h \sum_{i \in I_s} \tilde{h}_i + \lambda},$$

which is accurate enough in many cases.
However, for some loss functions such as ranking losses, refitting the leaf values with the original gradients after the tree has finished growing was found to improve accuracy. BitBoost [8] takes a similar approach, but unlike BitBoost, this method still accounts for the hessians in the split gain while the tree is growing, whereas BitBoost treats the hessians as constants and uses the true hessians only when refitting the leaf values.
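A small numerical sketch (mine, not the paper's code, with toy numbers and round-to-nearest used just for brevity) of the quantized leaf value versus a leaf value refitted with the original gradients:

import numpy as np

rng = np.random.default_rng(0)
g = rng.normal(size=1000)   # original gradients of the samples in one leaf
h = np.ones_like(g)         # constant hessians, e.g. squared-error loss
lam = 1.0

B = 3
delta_g = np.abs(g).max() / (2 ** (B - 1) - 1)
tilde_g = np.round(g / delta_g)

w_quantized = -delta_g * tilde_g.sum() / (h.sum() + lam)  # leaf value from quantized gradients
w_refitted = -g.sum() / (h.sum() + lam)                   # refit with the original gradients
print(w_quantized, w_refitted)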
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
def stochastic_round(x: float) -> int:
    # Round down with probability ceil(x) - x and up with probability x - floor(x),
    # so that E[stochastic_round(x)] = x. Assumes x is not already an integer.
    f = np.floor(x)
    c = np.ceil(x)
    return int(np.random.choice([f, c], p=[c - x, x - f]))
np.random.seed(0)
e = 0.001
x = np.linspace(e, 1-e, 1000)
y = [stochastic_round(xi) for xi in x]
x_c = pd.cut(x, bins=10)
def mid_value(cat):
return cat.left + ((cat.right - cat.left) / 2)
cat_to_mid = {cat: mid_value(cat) for cat in x_c.categories}
df = pd.DataFrame({
"x": x,
"y": y,
"x_c": x_c.map(cat_to_mid, na_action=None)
})
agg = df.groupby("x_c", observed=True)["y"].agg(["mean", "std"]).reset_index()
fig, ax = plt.subplots()
ax.scatter(x, y, alpha=0.1, label="StochasticRound(x)", edgecolors="steelblue")
ax.errorbar(agg["x_c"], agg["mean"], yerr=agg["std"], alpha=0.9,
fmt="o", color="gray", label="mean & std of StochasticRound(x)")
# xerr=(x.max() - x.min()) / 10 / 2,
# for cat in x_c.categories:
# ax.axvline(cat.left, alpha=0.5, linestyle="--", color="gray")
# ax.axvline(cat.right, alpha=0.5, linestyle="--", color="gray")
ax.set(
xlabel="x",
ylabel="StochasticRound(x)",
title="",
)
ax.legend()
fig.show()

Theoretical guarantees from stochastic rounding#
Thanks to stochastic rounding, Section 5 of the paper provides theorems showing that, with high probability, the error in the estimated split gain is bounded by a small value.
Implementation#
Caveat
If the leaf values are recomputed with the original (pre-quantization) gradients to improve accuracy, both training time and model size get worse.
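For reference, a minimal sketch of the LightGBM parameters that enable quantized training; these are the same keys used in the experiment below, and the specific values here are just illustrative:

import lightgbm as lgb

params = {
    "objective": "mse",
    "use_quantized_grad": True,       # enable quantized training
    "num_grad_quant_bins": 4,         # number of bins for the quantized gradients
    "quant_train_renew_leaf": True,   # refit leaf values with the original gradients
}
# booster = lgb.train(params, lgb.Dataset(X_train, y_train))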
Experiment: comparing training with and without quantization#
import lightgbm as lgb
print(f"{lgb.__version__=}")
def gen_dataset(seed=0):
from sklearn.datasets import make_regression
X, y = make_regression(n_samples=5_000, n_features=10, noise=0.5, random_state=seed)
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.8, random_state=0)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, train_size=0.8, random_state=0)
import lightgbm as lgb
lgb_train = lgb.Dataset(X_train, y_train)
lgb_val = lgb.Dataset(X_val, y_val, reference=lgb_train)
return lgb_train, lgb_val, X_test, y_test
def train(params, lgb_train, lgb_val):
logs = {}
model = lgb.train(
params,
lgb_train,
num_boost_round=100_000,
valid_sets=[lgb_train, lgb_val],
callbacks=[
lgb.early_stopping(stopping_rounds=100),
lgb.log_evaluation(period=100000),
lgb.record_evaluation(logs)
]
)
return model, logs
from tqdm import tqdm
from time import time
from pathlib import Path
from sklearn.metrics import root_mean_squared_error
import joblib
import pandas as pd
def get_model_size(model):
model_path = Path("model.joblib")
with open(model_path, "wb") as f:
joblib.dump(model, f)
size_mb = model_path.stat().st_size / 1024**2
model_path.unlink()
return size_mb
def trials(name, params, n_trials=10):
train_times = []
pred_times = []
rmses = []
model_sizes = []
for i in tqdm(range(n_trials)):
lgb_train, lgb_val, X_test, y_test = gen_dataset(seed=i)
t0 = time()
model, _logs = train(params, lgb_train, lgb_val)
t1 = time()
train_times.append(t1 - t0)
t0 = time()
y_pred = model.predict(X_test)
t1 = time()
pred_times.append(t1 - t0)
rmse = root_mean_squared_error(y_test, y_pred)
rmses.append(rmse)
size = get_model_size(model)
model_sizes.append(size)
result = pd.DataFrame({
"Training time": train_times,
"Inference time": pred_times,
"RMSE": rmses,
"Model Size (MB)": model_sizes
})
result["条件"] = name
return result
import os
print(f"{os.cpu_count()=}")
n_trials = 10
results = []
params = {
'device_type': 'cpu',
'num_threads': os.cpu_count(),
'objective': 'mse',
'metric': 'rmse',
'num_leaves': 31,
'learning_rate': 0.1,
'feature_fraction': 0.8,
'verbose': -1,
'seed': 0,
'deterministic': True,
}
# Baseline: no quantization
result = trials(name="量子化なし", params=params, n_trials=n_trials)
results.append(result)
# Quantized training
params["use_quantized_grad"] = True
params["num_grad_quant_bins"] = 4
params["quant_train_renew_leaf"] = False
result = trials(name="量子化あり・Renewなし", params=params, n_trials=n_trials)
results.append(result)
params["quant_train_renew_leaf"] = True
result = trials(name="量子化あり・Renewあり", params=params, n_trials=n_trials)
results.append(result)
lgb.__version__='4.6.0'
os.cpu_count()=4
0%| | 0/10 [00:00<?, ?it/s]
Training until validation scores don't improve for 100 rounds
Early stopping, best iteration is:
[1781] training's rmse: 0.115047 valid_1's rmse: 41.0182
10%|█ | 1/10 [00:01<00:12, 1.41s/it]
Training until validation scores don't improve for 100 rounds
20%|██ | 2/10 [00:02<00:10, 1.35s/it]
Early stopping, best iteration is:
[1650] training's rmse: 0.0972269 valid_1's rmse: 27.826
Training until validation scores don't improve for 100 rounds
Early stopping, best iteration is:
[3078] training's rmse: 0.00376388 valid_1's rmse: 46.6675
30%|███ | 3/10 [00:04<00:12, 1.77s/it]
Training until validation scores don't improve for 100 rounds
40%|████ | 4/10 [00:06<00:09, 1.59s/it]
Early stopping, best iteration is:
[1642] training's rmse: 0.119979 valid_1's rmse: 32.1281
Training until validation scores don't improve for 100 rounds
Early stopping, best iteration is:
[3313] training's rmse: 0.00137981 valid_1's rmse: 25.5328
50%|█████ | 5/10 [00:08<00:09, 1.90s/it]
Training until validation scores don't improve for 100 rounds
60%|██████ | 6/10 [00:10<00:06, 1.68s/it]
Early stopping, best iteration is:
[1616] training's rmse: 0.130356 valid_1's rmse: 36.1539
Training until validation scores don't improve for 100 rounds
Early stopping, best iteration is:
[3704] training's rmse: 0.000824281 valid_1's rmse: 39.8657
70%|███████ | 7/10 [00:12<00:06, 2.01s/it]
Training until validation scores don't improve for 100 rounds
80%|████████ | 8/10 [00:13<00:03, 1.70s/it]
Early stopping, best iteration is:
[1251] training's rmse: 0.252942 valid_1's rmse: 17.2222
Training until validation scores don't improve for 100 rounds
90%|█████████ | 9/10 [00:14<00:01, 1.53s/it]
Early stopping, best iteration is:
[1454] training's rmse: 0.237015 valid_1's rmse: 38.8602
Training until validation scores don't improve for 100 rounds
Early stopping, best iteration is:
[2294] training's rmse: 0.020392 valid_1's rmse: 23.7294
100%|██████████| 10/10 [00:16<00:00, 1.66s/it]
0%| | 0/10 [00:00<?, ?it/s]
Training until validation scores don't improve for 100 rounds
10%|█ | 1/10 [00:00<00:08, 1.05it/s]
Early stopping, best iteration is:
[1093] training's rmse: 0.508697 valid_1's rmse: 36.2037
Training until validation scores don't improve for 100 rounds
20%|██ | 2/10 [00:01<00:05, 1.44it/s]
Early stopping, best iteration is:
[427] training's rmse: 3.18506 valid_1's rmse: 24.8431
Training until validation scores don't improve for 100 rounds
30%|███ | 3/10 [00:02<00:05, 1.24it/s]
Early stopping, best iteration is:
[944] training's rmse: 0.878267 valid_1's rmse: 41.9654
Training until validation scores don't improve for 100 rounds
40%|████ | 4/10 [00:03<00:04, 1.34it/s]
Early stopping, best iteration is:
[677] training's rmse: 1.58701 valid_1's rmse: 29.5476
Training until validation scores don't improve for 100 rounds
50%|█████ | 5/10 [00:03<00:03, 1.47it/s]
Early stopping, best iteration is:
[572] training's rmse: 1.90926 valid_1's rmse: 22.9558
Training until validation scores don't improve for 100 rounds
60%|██████ | 6/10 [00:04<00:03, 1.24it/s]
Early stopping, best iteration is:
[1201] training's rmse: 0.331789 valid_1's rmse: 33.0541
Training until validation scores don't improve for 100 rounds
70%|███████ | 7/10 [00:05<00:02, 1.10it/s]
Early stopping, best iteration is:
[1292] training's rmse: 0.278446 valid_1's rmse: 36.915
Training until validation scores don't improve for 100 rounds
80%|████████ | 8/10 [00:06<00:01, 1.36it/s]
Early stopping, best iteration is:
[261] training's rmse: 3.74099 valid_1's rmse: 16.3203
Training until validation scores don't improve for 100 rounds
90%|█████████ | 9/10 [00:07<00:00, 1.29it/s]
Early stopping, best iteration is:
[991] training's rmse: 0.710558 valid_1's rmse: 36.1448
Training until validation scores don't improve for 100 rounds
100%|██████████| 10/10 [00:07<00:00, 1.29it/s]
Early stopping, best iteration is:
[636] training's rmse: 1.47725 valid_1's rmse: 22.2107
0%| | 0/10 [00:00<?, ?it/s]
Training until validation scores don't improve for 100 rounds
Early stopping, best iteration is:
[2524] training's rmse: 0.237055 valid_1's rmse: 33.2555
10%|█ | 1/10 [00:02<00:18, 2.07s/it]
Training until validation scores don't improve for 100 rounds
20%|██ | 2/10 [00:03<00:13, 1.69s/it]
Early stopping, best iteration is:
[1635] training's rmse: 0.713115 valid_1's rmse: 23.7664
Training until validation scores don't improve for 100 rounds
20%|██ | 2/10 [00:05<00:21, 2.63s/it]
---------------------------------------------------------------------------
KeyboardInterrupt Traceback (most recent call last)
Cell In[3], line 117
114 results.append(result)
116 params["quant_train_renew_leaf"] = True
--> 117 result = trials(name="量子化あり・Renewあり", params=params, n_trials=n_trials)
118 results.append(result)
Cell In[3], line 62, in trials(name, params, n_trials)
59 lgb_train, lgb_val, X_test, y_test = gen_dataset(seed=i)
61 t0 = time()
---> 62 model, _logs = train(params, lgb_train, lgb_val)
63 t1 = time()
64 train_times.append(t1 - t0)
Cell In[3], line 21, in train(params, lgb_train, lgb_val)
19 def train(params, lgb_train, lgb_val):
20 logs = {}
---> 21 model = lgb.train(
22 params,
23 lgb_train,
24 num_boost_round=100_000,
25 valid_sets=[lgb_train, lgb_val],
26 callbacks=[
27 lgb.early_stopping(stopping_rounds=100),
28 lgb.log_evaluation(period=100000),
29 lgb.record_evaluation(logs)
30 ]
31 )
32 return model, logs
File /usr/local/lib/python3.10/site-packages/lightgbm/engine.py:322, in train(params, train_set, num_boost_round, valid_sets, valid_names, feval, init_model, keep_training_booster, callbacks)
310 for cb in callbacks_before_iter:
311 cb(
312 callback.CallbackEnv(
313 model=booster,
(...)
319 )
320 )
--> 322 booster.update(fobj=fobj)
324 evaluation_result_list: List[_LGBM_BoosterEvalMethodResultType] = []
325 # check evaluation result.
File /usr/local/lib/python3.10/site-packages/lightgbm/basic.py:4155, in Booster.update(self, train_set, fobj)
4152 if self.__set_objective_to_none:
4153 raise LightGBMError("Cannot update due to null objective function.")
4154 _safe_call(
-> 4155 _LIB.LGBM_BoosterUpdateOneIter(
4156 self._handle,
4157 ctypes.byref(is_finished),
4158 )
4159 )
4160 self.__is_predicted_cur_iter = [False for _ in range(self.__num_dataset)]
4161 return is_finished.value == 1
KeyboardInterrupt:
import matplotlib.pyplot as plt
import seaborn as sns
import japanize_matplotlib  # Japanese font support for the "条件" legend

res = pd.concat(results)

fig, axes = plt.subplots(figsize=[10, 8], nrows=2, ncols=2)
fig.subplots_adjust(hspace=0.3)

# One KDE panel per metric, colored by experimental condition.
cols = ["Training time", "Inference time", "RMSE", "Model Size (MB)"]
for ax, col in zip(axes.ravel(), cols):
    sns.kdeplot(data=res, x=col, hue="条件", common_norm=False, ax=ax)
    ax.set(title=col)
fig.show()
