最適輸送#

2つの確率分布の距離を測ったりするのに使える技術。

最適輸送だと点群や集合の比較ができたり応用の幅が広く、確率分布の比較にも使える

交差エントロピー等との違い(メリット)は、誤差を非対称に(クラスAをBと間違えた場合とその逆の場合の損失を異なる値に)できること

離散的な操作も微分できる

離散最適輸送は線形計画#

ヒストグラムの最適輸送距離の定式化#

  • 入力:

    • 比較するヒストグラム a,bRn

    • 各点の距離を表す行列 CRn×n

  • 出力:ヒストグラムの距離 OT(a,b,C)

  • 最適輸送距離を以下の最適化問題の最適値と定義する

minimizePRn×ni=1nj=1nCijPij() s.t. Pij0i,j()j=1nPij=aii()i=1nPij=bjj()

ここで決定変数Pijは点iから点jに輸送する量を表す

cvxpyで解く例#

※ちゃんとやるならPOTパッケージを使うべき

import numpy as np

# 離散分布を想定
a = np.array([0.2,0.5,0.2,0.1])
b = np.array([0.3,0.3,0.4,0.0])
C = np.array([
    [0,2,2,2],
    [2,0,1,2],
    [2,1,0,2],
    [2,2,2,0]
])


import cvxpy as cp
n = C.shape[0]
P = cp.Variable((n, n))

objective = cp.Minimize(cp.multiply(C, P).sum())
constraints = [
    P >= 0,
    cp.sum(P, axis=1) == a, # jについて総和をとったものがa_iと等しい
    cp.sum(P, axis=0) == b, # iについて総和をとったものがb_jと等しい
]
prob = cp.Problem(objective, constraints)
prob.solve()
0.3999999840304966
# 最適輸送距離 OT(a, b, C)
print(f"最適輸送距離 OT(a, b, C): {prob.value:.3f}")
最適輸送距離 OT(a, b, C): 0.400
print(f"行列P:\n{P.value.round(3)}")
行列P:
[[ 0.2 -0.  -0.  -0. ]
 [ 0.   0.3  0.2 -0. ]
 [ 0.  -0.   0.2 -0. ]
 [ 0.1  0.   0.   0. ]]

点群の最適輸送距離の定式化#

  • 入力:

    • 比較する点群 {x1,,xn},{y1,,ym}X

    • 各点の距離を表す関数 C:X×XR

  • 出力:点群の距離

  • 最適輸送距離を以下の最適化問題の最適値と定義する

minimizePRn×mi=1nj=1mC(xi,yj)Pij() s.t. Pij0i,j()j=1mPij=1ni()i=1nPij=1mj()

ここで決定変数Pijは点iから点jに輸送する量を表す

連続最適輸送#

連続の場合は難しいので、サンプリングして点群にしたり双対にしたりする(最適輸送入門 - Speaker Deck

関連文献#