最適輸送#
2つの確率分布の距離を測ったりするのに使える技術。
最適輸送だと点群や集合の比較ができたり応用の幅が広く、確率分布の比較にも使える
交差エントロピー等との違い(メリット)は、誤差を非対称に(クラスAをBと間違えた場合とその逆の場合の損失を異なる値に)できること
離散的な操作も微分できる
ランキング、最短経路問題など
離散最適輸送は線形計画#
ヒストグラムの最適輸送距離の定式化#
入力:
比較するヒストグラム
各点の距離を表す行列
出力:ヒストグラムの距離
最適輸送距離を以下の最適化問題の最適値と定義する
ここで決定変数
cvxpyで解く例#
※ちゃんとやるならPOTパッケージを使うべき
import numpy as np
# 離散分布を想定
a = np.array([0.2,0.5,0.2,0.1])
b = np.array([0.3,0.3,0.4,0.0])
C = np.array([
[0,2,2,2],
[2,0,1,2],
[2,1,0,2],
[2,2,2,0]
])
import cvxpy as cp
n = C.shape[0]
P = cp.Variable((n, n))
objective = cp.Minimize(cp.multiply(C, P).sum())
constraints = [
P >= 0,
cp.sum(P, axis=1) == a, # jについて総和をとったものがa_iと等しい
cp.sum(P, axis=0) == b, # iについて総和をとったものがb_jと等しい
]
prob = cp.Problem(objective, constraints)
prob.solve()
0.3999999840304966
# 最適輸送距離 OT(a, b, C)
print(f"最適輸送距離 OT(a, b, C): {prob.value:.3f}")
最適輸送距離 OT(a, b, C): 0.400
print(f"行列P:\n{P.value.round(3)}")
行列P:
[[ 0.2 -0. -0. -0. ]
[ 0. 0.3 0.2 -0. ]
[ 0. -0. 0.2 -0. ]
[ 0.1 0. 0. 0. ]]
点群の最適輸送距離の定式化#
入力:
比較する点群
各点の距離を表す関数
出力:点群の距離
最適輸送距離を以下の最適化問題の最適値と定義する
ここで決定変数
連続最適輸送#
連続の場合は難しいので、サンプリングして点群にしたり双対にしたりする(最適輸送入門 - Speaker Deck)