影響関数(influence function)#

データ点ztestの予測において訓練データ点zが与えた影響の大きさを評価する関数I(z,ztest)を推定する技術。

アイデア#

  • 経験リスクR(θ)=1ni=1nL(zi,θ)の最小化による学習アルゴリズムを前提とする。

  • 訓練データ全件で学習したモデルθ^と、訓練データからデータ点zを抜いて学習したモデルθ^zとの差分θ^zθ^でインスタンスzの影響度がわかる

  • インスタンスzのデータ点ztestに対する影響度は、誤差L(ztest,θ)の大きさからわかる

  • データ点ztestの予測において訓練データ点zが与えた影響の大きさ L(ztest,θ^z)L(ztest,θ^)を推定する

文献#

  • Koh & Liang (2017)が提案

  • Hara et al. (2019)はconvexでない損失関数にも使えるよう拡張したらしい

  • Guo et al. (2020)はFastIFを提案:k-最近傍のデータに探索範囲を絞るなどして最大80倍の高速化

  • Sharchilev et al. (2018)は勾配ブースティング決定木に向けたinfluence functionを提案

    • Influence Functionはモデルのパラメータθの微分によって近似推定するため、決定木ベースのアルゴリズムなど微分不可能なアルゴリズムでは計算できないため

Notation#

  • データ点z=(x,y)

  • 誤差関数L(z,θ)

    • 経験リスク:R(θ)=1ni=1nL(zi,θ)

  • N訓練データ点が訓練集合Zに含まれるとする

  • 標準的な経験リスク最小化:θ^=argminθ1ni=1nL(zi,θ)

Leave-One-Out#

1つのインスタンスzを抜いて訓練した場合にどれだけモデルθが変わるかを考える

θ^zθ^θ^z:=argminθΘzizL(zi,θ)

モデルの変化がわかればテストデータ点ztestに対する誤差の変化も評価できる

L(ztest,θ^z)L(ztest,θ^)

n個の訓練データ全部を評価するにはn回学習し直す必要があるので現実的なアプローチではない

Influence Function#

アイデア#

数学的なトリックを活用して再学習を避けつつLOOを近似していく。

訓練データ全部を使った経験リスクに、データ点zについての誤差L(z,θ)を重み付きで足したリスク関数で訓練したパラメータθ^ϵ,zを考える

θ^ϵ,z:=argminθ1ni=1nL(zi,θ)+ϵL(z,θ)

zを入れることによる変化分がとれるので、これのϵでの微分のϵ=0のときの値をパラメータについての上側のinfluence functionとする

Iup,params(z):=dθ^ϵ,zdϵ|ϵ=0=Hθ^1θL(z,θ^)
導出

Koh & Liang (2017)のAppendixより)

up-weightedされた経験リスクのもとでの推定量は

θ^ϵ,z=argminθΘ{R(θ)+ϵL(z,θ)}

これともとの推定量の差をΔϵとする

Δϵ=θ^ϵ,zθ^

第二項はϵと関係無いので、ϵで微分すると消える

dΔϵdϵ=dθ^ϵ,zdϵ

θ^ϵ,z は arg minの解なので最適性条件を満たす、つまりup-weightedされた経験リスクを微分してゼロとなるポイント

0=R(θ^ϵ,z)+ϵL(z,θ^ϵ,z)

ϵ0とすると

θ^ϵ,z=argminθ1ni=1nL(zi,θ)+ϵL(z,θ)argminθ1ni=1nL(zi,θ)+0=θ

ゆえにθ^ϵ,zθ^なので、テイラー展開f(x)=f(a)+f(a)(xa)+を用いると

0=[R(θ^)+ϵL(z,θ^)]+[2R(θ^)+ϵ2L(z,θ^)]Δϵ+o(Δϵ)[R(θ^)+ϵL(z,θ^)]+[2R(θ^)+ϵ2L(z,θ^)]Δϵ

となる。整理すると

Δϵ[2R(θ^)+ϵ2L(z,θ^)]1[R(θ^ϵ,z)+ϵL(z,θ^)]

であり、θ^R(θ^)を最小化してR(θ^)=0になると考え、o(ϵ)の(ϵに比例する)項を消すと

Δϵ[2R(θ^)+ϵ2L(z,θ^)]1[R(θ^ϵ,z)=0+ϵL(z,θ^)]

より

Δϵ2R(θ^)1L(z,θ^)ϵ=Hθ^1L(z,θ^)ϵ

ϵL(z,θ^)は残るのはなぜ????)

よって

dθ^ϵ,zdϵ|ϵ=0=Hθ^1θL(z,θ^)

LOOの近似#

もしϵ=1nなら

θ^ϵ,z=argminθ1ni=1nL(zi,θ)1nL(z,θ)

もしzが訓練データに含まれるなら、訓練データからzを除去した場合と同様なので、LOOの考え方を近似できる

θ^zθ^1nIup,params(z)

損失へのupweighting#

データ点ztestにおけるinfluenceを計算するため、微分の連鎖律を使う

Iup,loss(z,ztest):=dL(ztest,θ^ϵ,z)dϵ|ϵ=0=θL(ztest,θ^)dθ^ϵ,zdϵ|ϵ=0=θL(ztest,θ^)Hθ^1θL(z,θ^)

ヘッセ行列の計算の高速化#

n個のデータとp次元パラメータがあるとき、ヘッセ行列の逆行列はO(np3+p3)の計算量を要するので削減したい

  1. 共役勾配(conjugate gradients: CG)法

    • 連立一次方程式に帰着させて近似解を得る方法

  2. Hessian-vector products (HVPs)

    • ヘッセ行列とベクトルの積を近似推定する数学的トリックを使う方法

  3. Stochastic estimation (Agarwal et al., 2017)

    • SGDのようにランダムサンプリングしてCG法のiterationを回す

FastFI#

Guo, H., Rajani, N. F., Hase, P., Bansal, M., & Xiong, C. (2020). Fastif: Scalable influence functions for efficient model interpretation and debugging. arXiv preprint arXiv:2012.15781.

Influence Functionは計算が重い

  1. データ点の評価はO(n)

  2. モデルパラメータのinverse Hessianの計算コストが高い

  3. 上記の計算は並列可能であるが、先行研究のアルゴリズムでは直列

FastIfのアイデア

  1. 全データを探索するのではなく、fast nearest neighbor search(Johnson et al., 2017)で探索範囲を狭め、桁違いに計算量を抑える

  2. Hessianの推定において、品質を保ちつつ時間を半分以下にするハイパーパラメータ集合を識別

  3. シンプルに並列計算へ拡張し、さらに2倍高速化

実験においてほとんどのケースで全体で2桁程度の高速化が確認された

LeafInfluence#

Case-deletion importance sampling estimators#

[0807.0725] Case-deletion importance sampling estimators: Central limit theorems and related results

Bayesian Modelのためのinfluence function的なアプローチ。データセットDからサンプルiを除いたデータDiの事後分布をP(θDi)とすると、サンプルを抜いたことの影響を分布の密度比P(θDi)P(θD) で評価できる。そしてこの密度比はデータを抜いて再学習しなくても計算可能、らしい。まだちゃんと読んでない。

参考文献#