最適輸送#

2つの確率分布の距離を測ったりするのに使える技術。

最適輸送だと点群や集合の比較ができたり応用の幅が広く、確率分布の比較にも使える

交差エントロピー等との違い(メリット)は、誤差を非対称に(クラスAをBと間違えた場合とその逆の場合の損失を異なる値に)できること

離散的な操作も微分できる

離散最適輸送は線形計画#

ヒストグラムの最適輸送距離の定式化#

  • 入力:

    • 比較するヒストグラム \(\boldsymbol{a}, \boldsymbol{b} \in \mathbb{R}^n\)

    • 各点の距離を表す行列 \(\boldsymbol{C} \in \mathbb{R}^{n\times n}\)

  • 出力:ヒストグラムの距離 \(\text{OT}(\boldsymbol{a}, \boldsymbol{b}, \boldsymbol{C})\)

  • 最適輸送距離を以下の最適化問題の最適値と定義する

\[\begin{split} \begin{align} \underset{P \in \mathbb{R}^{n \times n}}{\operatorname{minimize}} & \sum_{i=1}^n \sum_{j=1}^n C_{ij} P_{i j} \quad(総コスト)\\ \text { s.t. } & P_{i j} \geq 0 \quad \forall i, j \quad(輸送量は非負)\\ & \sum_{j=1}^n P_{i j}= a_i \quad \forall i \quad(余りなし)\\ & \sum^n_{i=1} P_{i j}= b_j \quad \forall j \quad(不足なし) \end{align} \end{split}\]

ここで決定変数\(P_{ij}\)は点\(i\)から点\(j\)に輸送する量を表す

cvxpyで解く例#

※ちゃんとやるならPOTパッケージを使うべき

import numpy as np

# 離散分布を想定
a = np.array([0.2,0.5,0.2,0.1])
b = np.array([0.3,0.3,0.4,0.0])
C = np.array([
    [0,2,2,2],
    [2,0,1,2],
    [2,1,0,2],
    [2,2,2,0]
])


import cvxpy as cp
n = C.shape[0]
P = cp.Variable((n, n))

objective = cp.Minimize(cp.multiply(C, P).sum())
constraints = [
    P >= 0,
    cp.sum(P, axis=1) == a, # jについて総和をとったものがa_iと等しい
    cp.sum(P, axis=0) == b, # iについて総和をとったものがb_jと等しい
]
prob = cp.Problem(objective, constraints)
prob.solve()
0.3999999840304966
# 最適輸送距離 OT(a, b, C)
print(f"最適輸送距離 OT(a, b, C): {prob.value:.3f}")
最適輸送距離 OT(a, b, C): 0.400
print(f"行列P:\n{P.value.round(3)}")
行列P:
[[ 0.2 -0.  -0.  -0. ]
 [ 0.   0.3  0.2 -0. ]
 [ 0.  -0.   0.2 -0. ]
 [ 0.1  0.   0.   0. ]]

点群の最適輸送距離の定式化#

  • 入力:

    • 比較する点群 \(\left\{x_1, \cdots, x_n\right\},\left\{y_1, \cdots, y_m\right\} \subset \mathcal{X}\)

    • 各点の距離を表す関数 \(C: \mathcal{X} \times \mathcal{X} \rightarrow \mathbb{R}\)

  • 出力:点群の距離

  • 最適輸送距離を以下の最適化問題の最適値と定義する

\[\begin{split} \begin{align} \underset{P \in \mathbb{R}^{n \times m}}{\operatorname{minimize}} & \sum_{i=1}^n \sum_{j=1}^m C\left(x_i, y_j\right) P_{i j} \quad(総コスト)\\ \text { s.t. } & P_{i j} \geq 0 \quad \forall i, j \quad(輸送量は非負)\\ & \sum_{j=1}^m P_{i j}=\frac{1}{n} \quad \forall i \quad(余りなし)\\ & \sum^n_{i=1} P_{i j}=\frac{1}{m} \quad \forall j \quad(不足なし) \end{align} \end{split}\]

ここで決定変数\(P_{ij}\)は点\(i\)から点\(j\)に輸送する量を表す

連続最適輸送#

連続の場合は難しいので、サンプリングして点群にしたり双対にしたりする(最適輸送入門 - Speaker Deck

関連文献#