最適輸送#
最適輸送(optimal transport: OT) は2つの点群(確率分布)のあいだの
距離を求める
対応関係を得る
変換する
といったことができる技術。
交差エントロピー等との違い(メリット)は、 誤差を非対称にできる こと(クラスAをBと間違えた場合とその逆の場合の損失を異なる値にしたりできる)
離散的な操作も微分できる
ランキング、最短経路問題など
離散最適輸送は線形計画#
ヒストグラムの最適輸送距離の定式化#
入力:
比較するヒストグラム \(\boldsymbol{a}, \boldsymbol{b} \in \mathbb{R}^n\)
各点の距離を表す行列 \(\boldsymbol{C} \in \mathbb{R}^{n\times n}\)
出力:ヒストグラムの距離 \(\text{OT}(\boldsymbol{a}, \boldsymbol{b}, \boldsymbol{C})\)
最適輸送距離を以下の最適化問題の最適値と定義する
ここで決定変数\(P_{ij}\)は点\(i\)から点\(j\)に輸送する量を表す
cvxpyで解く例#
※ちゃんとやるならPOTパッケージを使うべき
import numpy as np
# 離散分布を想定
a = np.array([0.2,0.5,0.2,0.1])
b = np.array([0.3,0.3,0.4,0.0])
C = np.array([
[0,2,2,2],
[2,0,1,2],
[2,1,0,2],
[2,2,2,0]
])
import cvxpy as cp
n = C.shape[0]
P = cp.Variable((n, n))
objective = cp.Minimize(cp.multiply(C, P).sum())
constraints = [
P >= 0,
cp.sum(P, axis=1) == a, # jについて総和をとったものがa_iと等しい
cp.sum(P, axis=0) == b, # iについて総和をとったものがb_jと等しい
]
prob = cp.Problem(objective, constraints)
prob.solve()
0.3999999840304966
# 最適輸送距離 OT(a, b, C)
print(f"最適輸送距離 OT(a, b, C): {prob.value:.3f}")
最適輸送距離 OT(a, b, C): 0.400
print(f"行列P:\n{P.value.round(3)}")
行列P:
[[ 0.2 -0. -0. -0. ]
[ 0. 0.3 0.2 -0. ]
[ 0. -0. 0.2 -0. ]
[ 0.1 0. 0. 0. ]]
点群の最適輸送距離の定式化#
入力:
比較する点群 \(\left\{x_1, \cdots, x_n\right\},\left\{y_1, \cdots, y_m\right\} \subset \mathcal{X}\)
各点の距離を表す関数 \(C: \mathcal{X} \times \mathcal{X} \rightarrow \mathbb{R}\)
出力:点群の距離
最適輸送距離を以下の最適化問題の最適値と定義する
ここで決定変数\(P_{ij}\)は点\(i\)から点\(j\)に輸送する量を表す
連続最適輸送#
連続の場合は難しいので、サンプリングして点群にしたり双対にしたりする(最適輸送入門 - Speaker Deck)