最適輸送#
2つの確率分布の距離を測ったりするのに使える技術。
最適輸送だと点群や集合の比較ができたり応用の幅が広く、確率分布の比較にも使える
交差エントロピー等との違い(メリット)は、誤差を非対称に(クラスAをBと間違えた場合とその逆の場合の損失を異なる値に)できること
離散的な操作も微分できる
ランキング、最短経路問題など
離散最適輸送は線形計画#
ヒストグラムの最適輸送距離の定式化#
入力:
比較するヒストグラム \(\boldsymbol{a}, \boldsymbol{b} \in \mathbb{R}^n\)
各点の距離を表す行列 \(\boldsymbol{C} \in \mathbb{R}^{n\times n}\)
出力:ヒストグラムの距離 \(\text{OT}(\boldsymbol{a}, \boldsymbol{b}, \boldsymbol{C})\)
最適輸送距離を以下の最適化問題の最適値と定義する
\[\begin{split}
\begin{align}
\underset{P \in \mathbb{R}^{n \times n}}{\operatorname{minimize}}
& \sum_{i=1}^n \sum_{j=1}^n C_{ij} P_{i j} \quad(総コスト)\\
\text { s.t. }
& P_{i j} \geq 0 \quad \forall i, j \quad(輸送量は非負)\\
& \sum_{j=1}^n P_{i j}= a_i \quad \forall i \quad(余りなし)\\
& \sum^n_{i=1} P_{i j}= b_j \quad \forall j \quad(不足なし)
\end{align}
\end{split}\]
ここで決定変数\(P_{ij}\)は点\(i\)から点\(j\)に輸送する量を表す
cvxpyで解く例#
※ちゃんとやるならPOTパッケージを使うべき
import numpy as np
# 離散分布を想定
a = np.array([0.2,0.5,0.2,0.1])
b = np.array([0.3,0.3,0.4,0.0])
C = np.array([
[0,2,2,2],
[2,0,1,2],
[2,1,0,2],
[2,2,2,0]
])
import cvxpy as cp
n = C.shape[0]
P = cp.Variable((n, n))
objective = cp.Minimize(cp.multiply(C, P).sum())
constraints = [
P >= 0,
cp.sum(P, axis=1) == a, # jについて総和をとったものがa_iと等しい
cp.sum(P, axis=0) == b, # iについて総和をとったものがb_jと等しい
]
prob = cp.Problem(objective, constraints)
prob.solve()
0.3999999840304966
# 最適輸送距離 OT(a, b, C)
print(f"最適輸送距離 OT(a, b, C): {prob.value:.3f}")
最適輸送距離 OT(a, b, C): 0.400
print(f"行列P:\n{P.value.round(3)}")
行列P:
[[ 0.2 -0. -0. -0. ]
[ 0. 0.3 0.2 -0. ]
[ 0. -0. 0.2 -0. ]
[ 0.1 0. 0. 0. ]]
点群の最適輸送距離の定式化#
入力:
比較する点群 \(\left\{x_1, \cdots, x_n\right\},\left\{y_1, \cdots, y_m\right\} \subset \mathcal{X}\)
各点の距離を表す関数 \(C: \mathcal{X} \times \mathcal{X} \rightarrow \mathbb{R}\)
出力:点群の距離
最適輸送距離を以下の最適化問題の最適値と定義する
\[\begin{split}
\begin{align}
\underset{P \in \mathbb{R}^{n \times m}}{\operatorname{minimize}}
& \sum_{i=1}^n \sum_{j=1}^m C\left(x_i, y_j\right) P_{i j} \quad(総コスト)\\
\text { s.t. }
& P_{i j} \geq 0 \quad \forall i, j \quad(輸送量は非負)\\
& \sum_{j=1}^m P_{i j}=\frac{1}{n} \quad \forall i \quad(余りなし)\\
& \sum^n_{i=1} P_{i j}=\frac{1}{m} \quad \forall j \quad(不足なし)
\end{align}
\end{split}\]
ここで決定変数\(P_{ij}\)は点\(i\)から点\(j\)に輸送する量を表す
連続最適輸送#
連続の場合は難しいので、サンプリングして点群にしたり双対にしたりする(最適輸送入門 - Speaker Deck)