影響関数(influence function)#

データ点\(z_{test}\)の予測において訓練データ点\(z\)が与えた影響の大きさを評価する関数\(\mathcal{I}(z, z_{test})\)を推定する技術。

アイデア#

  • 経験リスク\(R(\theta) = \frac{1}{n} \sum^n_{i=1} L(z_i, \theta)\)の最小化による学習アルゴリズムを前提とする。

  • 訓練データ全件で学習したモデル\(\hat\theta\)と、訓練データからデータ点\(z\)を抜いて学習したモデル\(\hat\theta_{-z}\)との差分\(\hat{\theta}_{-z} - \hat{\theta}\)でインスタンス\(z\)の影響度がわかる

  • インスタンス\(z\)のデータ点\(z_{test}\)に対する影響度は、誤差\(L(z_{test}, \theta)\)の大きさからわかる

  • データ点\(z_{test}\)の予測において訓練データ点\(z\)が与えた影響の大きさ \(L(z_{test}, \hat{\theta}_{-z}) - L(z_{test}, \hat{\theta})\)を推定する

文献#

  • Koh & Liang (2017)が提案

  • Hara et al. (2019)はconvexでない損失関数にも使えるよう拡張したらしい

  • Guo et al. (2020)はFastIFを提案:k-最近傍のデータに探索範囲を絞るなどして最大80倍の高速化

  • Sharchilev et al. (2018)は勾配ブースティング決定木に向けたinfluence functionを提案

    • Influence Functionはモデルのパラメータ\(\theta\)の微分によって近似推定するため、決定木ベースのアルゴリズムなど微分不可能なアルゴリズムでは計算できないため

Notation#

  • データ点\(z=(x, y)\)

  • 誤差関数\(L(z, \theta)\)

    • 経験リスク:\(R(\theta) = \frac{1}{n} \sum^n_{i=1} L(z_i, \theta)\)

  • \(N\)訓練データ点が訓練集合\(\mathcal{Z}\)に含まれるとする

  • 標準的な経験リスク最小化:\(\hat{\theta} = \arg \min_{\theta} \frac{1}{n} \sum^n_{i=1} L(z_i, \theta)\)

Leave-One-Out#

1つのインスタンス\(z\)を抜いて訓練した場合にどれだけモデル\(\theta\)が変わるかを考える

\[\begin{split} \hat{\theta}_{-z} - \hat{\theta}\\ \hat{\theta}_{-z} := \arg \min_{\theta \in \Theta} \sum_{z_i \neq z} L(z_i, \theta) \end{split}\]

モデルの変化がわかればテストデータ点\(z_{test}\)に対する誤差の変化も評価できる

\[ L(z_{test}, \hat{\theta}_{-z}) - L(z_{test}, \hat{\theta}) \]

\(n\)個の訓練データ全部を評価するには\(n\)回学習し直す必要があるので現実的なアプローチではない

Influence Function#

アイデア#

数学的なトリックを活用して再学習を避けつつLOOを近似していく。

訓練データ全部を使った経験リスクに、データ点\(z\)についての誤差\(L(z, \theta)\)を重み付きで足したリスク関数で訓練したパラメータ\(\hat{\theta}_{\epsilon,z}\)を考える

\[ \hat{\theta}_{\epsilon,z} := \arg \min_{\theta} \frac{1}{n} \sum^n_{i=1} L(z_i, \theta) + \epsilon L(z, \theta) \]

\(z\)を入れることによる変化分がとれるので、これの\(\epsilon\)での微分の\(\epsilon = 0\)のときの値をパラメータについての上側のinfluence functionとする

\[ \mathcal{I}_{up, params}(z) := \left . \frac{d \hat{\theta}_{\epsilon,z} }{ d \epsilon } \right | _{\epsilon = 0} = -H^{-1}_{\hat{\theta}} \nabla_{\theta} L(z, \hat{\theta}) \]
導出

Koh & Liang (2017)のAppendixより)

up-weightedされた経験リスクのもとでの推定量は

\[ \hat{\theta}_{\epsilon, z} = \arg \min_{\theta \in \Theta} \big \{ R(\theta) + \epsilon L(z, \theta) \big \} \]

これともとの推定量の差を\(\Delta_{\epsilon}\)とする

\[ \Delta_{\epsilon} = \hat{\theta}_{\epsilon, z} - \hat{\theta} \]

第二項は\(\epsilon\)と関係無いので、\(\epsilon\)で微分すると消える

\[ \frac{d \Delta_{\epsilon} }{d \epsilon} = \frac{d \hat{\theta}_{\epsilon, z} }{d \epsilon} \]

\(\hat{\theta}_{\epsilon, z}\) は arg minの解なので最適性条件を満たす、つまりup-weightedされた経験リスクを微分してゼロとなるポイント

\[ 0 = \nabla R(\hat{\theta}_{\epsilon, z}) + \epsilon \nabla L(z, \hat{\theta}_{\epsilon, z}) \]

\(\epsilon \to 0\)とすると

\[ \hat{\theta}_{\epsilon, z} = \arg \min_{\theta} \frac{1}{n} \sum^n_{i=1} L(z_i, \theta) + \epsilon L(z, \theta) \to \arg \min_{\theta} \frac{1}{n} \sum^n_{i=1} L(z_i, \theta) + 0 = \theta \]

ゆえに\(\hat{\theta}_{\epsilon, z} \to \hat{\theta}\)なので、テイラー展開\(f(x) = f(a) + f(a) (x - a) + \cdots\)を用いると

\[\begin{split} \begin{align} 0 &= \big[ \nabla R(\hat{\theta}) + \epsilon \nabla L(z, \hat{\theta}) \big] + \big[ \nabla^2 R(\hat{\theta}) + \epsilon \nabla^2 L(z, \hat{\theta}) \big] \Delta_{\epsilon} + o(\| \Delta_{\epsilon} \|) \\ &\approx \big[ \nabla R(\hat{\theta}) + \epsilon \nabla L(z, \hat{\theta}) \big] + \big[ \nabla^2 R(\hat{\theta}) + \epsilon \nabla^2 L(z, \hat{\theta}) \big] \Delta_{\epsilon} \end{align} \end{split}\]

となる。整理すると

\[ \Delta_{\epsilon} \approx - \big[ \nabla^2 R(\hat{\theta}) + \epsilon \nabla^2 L(z, \hat{\theta}) \big]^{-1} \big[ \nabla R(\hat{\theta}_{\epsilon, z}) + \epsilon \nabla L(z, \hat{\theta}) \big] \]

であり、\(\hat{\theta}\)\(R(\hat{\theta})\)を最小化して\(\nabla R(\hat{\theta})=0\)になると考え、\(o(\epsilon)\)の(\(\epsilon\)に比例する)項を消すと

\[ \Delta_{\epsilon} \approx - \big[ \nabla^2 R(\hat{\theta}) + \epsilon \nabla^2 L(z, \hat{\theta}) \big]^{-1} \big[ \underbrace{ \nabla R(\hat{\theta}_{\epsilon, z}) }_{=0} + \epsilon \nabla L(z, \hat{\theta}) \big] \]

より

\[\begin{split} \Delta_{\epsilon} \approx - \nabla^2 R(\hat{\theta})^{-1} \nabla L(z, \hat{\theta}) \epsilon \\ = - H_{\hat{\theta}}^{-1} \nabla L(z, \hat{\theta}) \epsilon \end{split}\]

\(\epsilon \nabla L(z, \hat{\theta})\)は残るのはなぜ????)

よって

\[ \left . \frac{d \hat{\theta}_{\epsilon,z} }{ d \epsilon } \right | _{\epsilon = 0} = -H^{-1}_{\hat{\theta}} \nabla_{\theta} L(z, \hat{\theta}) \]

LOOの近似#

もし\(\epsilon = -\frac{1}{n}\)なら

\[ \hat{\theta}_{\epsilon,z} = \arg \min_{\theta} \frac{1}{n} \sum^n_{i=1} L(z_i, \theta) - \frac{1}{n} L(z, \theta) \]

もし\(z\)が訓練データに含まれるなら、訓練データから\(z\)を除去した場合と同様なので、LOOの考え方を近似できる

\[ \hat{\theta}_{-z} - \hat{\theta} \approx -\frac{1}{n} \mathcal{I}_{up, params}(z) \]

損失へのupweighting#

データ点\(z_{test}\)におけるinfluenceを計算するため、微分の連鎖律を使う

\[\begin{split} \begin{align} \mathcal{I}_{up, loss}(z, z_{test}) :&= \left . \frac{d L(z_{test}, \hat{\theta}_{\epsilon, z}) }{ d \epsilon } \right | _{\epsilon = 0} \\ &= \nabla_{\theta} L(z_{test}, \hat{\theta})^\top \left . \frac{d \hat{\theta}_{\epsilon, z} }{ d \epsilon } \right | _{\epsilon = 0} \\ &= - \nabla_{\theta} L(z_{test}, \hat{\theta})^\top H^{-1}_{\hat{\theta}} \nabla_{\theta} L(z, \hat{\theta}) \end{align} \end{split}\]

ヘッセ行列の計算の高速化#

\(n\)個のデータと\(p\)次元パラメータがあるとき、ヘッセ行列の逆行列は\(O(n p^3 + p^3)\)の計算量を要するので削減したい

  1. 共役勾配(conjugate gradients: CG)法

    • 連立一次方程式に帰着させて近似解を得る方法

  2. Hessian-vector products (HVPs)

    • ヘッセ行列とベクトルの積を近似推定する数学的トリックを使う方法

  3. Stochastic estimation (Agarwal et al., 2017)

    • SGDのようにランダムサンプリングしてCG法のiterationを回す

FastFI#

Guo, H., Rajani, N. F., Hase, P., Bansal, M., & Xiong, C. (2020). Fastif: Scalable influence functions for efficient model interpretation and debugging. arXiv preprint arXiv:2012.15781.

Influence Functionは計算が重い

  1. データ点の評価は\(O(n)\)

  2. モデルパラメータのinverse Hessianの計算コストが高い

  3. 上記の計算は並列可能であるが、先行研究のアルゴリズムでは直列

FastIfのアイデア

  1. 全データを探索するのではなく、fast nearest neighbor search(Johnson et al., 2017)で探索範囲を狭め、桁違いに計算量を抑える

  2. Hessianの推定において、品質を保ちつつ時間を半分以下にするハイパーパラメータ集合を識別

  3. シンプルに並列計算へ拡張し、さらに2倍高速化

実験においてほとんどのケースで全体で2桁程度の高速化が確認された

LeafInfluence#

参考文献#