最適輸送#

2つの確率分布の距離を測ったりするのに使える技術。

最適輸送だと点群や集合の比較ができたり応用の幅が広く、確率分布の比較にも使える

交差エントロピー等との違い(メリット)は、誤差を非対称に(クラスAをBと間違えた場合とその逆の場合の損失を異なる値に)できること

離散的な操作も微分できる

離散最適輸送は線形計画#

ヒストグラムの最適輸送距離の定式化#

  • 入力:

    • 比較するヒストグラム \(\boldsymbol{a}, \boldsymbol{b} \in \mathbb{R}^n\)

    • 各点の距離を表す行列 \(\boldsymbol{C} \in \mathbb{R}^{n\times n}\)

  • 出力:ヒストグラムの距離 \(\text{OT}(\boldsymbol{a}, \boldsymbol{b}, \boldsymbol{C})\)

  • 最適輸送距離を以下の最適化問題の最適値と定義する

\[\begin{split} \begin{align} \underset{P \in \mathbb{R}^{n \times n}}{\operatorname{minimize}} & \sum_{i=1}^n \sum_{j=1}^n C_{ij} P_{i j} \quad(総コスト)\\ \text { s.t. } & P_{i j} \geq 0 \quad \forall i, j \quad(輸送量は非負)\\ & \sum_{j=1}^n P_{i j}= a_i \quad \forall i \quad(余りなし)\\ & \sum^n_{i=1} P_{i j}= b_j \quad \forall j \quad(不足なし) \end{align} \end{split}\]

ここで決定変数\(P_{ij}\)は点\(i\)から点\(j\)に輸送する量を表す

cvxpyで解く例#

※ちゃんとやるならPOTパッケージを使うべき

import numpy as np

# 離散分布を想定
a = np.array([0.2,0.5,0.2,0.1])
b = np.array([0.3,0.3,0.4,0.0])
C = np.array([
    [0,2,2,2],
    [2,0,1,2],
    [2,1,0,2],
    [2,2,2,0]
])


import cvxpy as cp
n = C.shape[0]
P = cp.Variable((n, n))

objective = cp.Minimize(cp.multiply(C, P).sum())
constraints = [
    P >= 0,
    cp.sum(P, axis=1) == a, # jについて総和をとったものがa_iと等しい
    cp.sum(P, axis=0) == b, # iについて総和をとったものがb_jと等しい
]
prob = cp.Problem(objective, constraints)
prob.solve()
/usr/local/lib/python3.10/site-packages/cvxpy/reductions/solvers/solving_chain.py:336: FutureWarning: 
    Your problem is being solved with the ECOS solver by default. Starting in 
    CVXPY 1.5.0, Clarabel will be used as the default solver instead. To continue 
    using ECOS, specify the ECOS solver explicitly using the ``solver=cp.ECOS`` 
    argument to the ``problem.solve`` method.
    
  warnings.warn(ECOS_DEPRECATION_MSG, FutureWarning)
0.3999999997741462
# 最適輸送距離 OT(a, b, C)
print(f"最適輸送距離 OT(a, b, C): {prob.value:.3f}")
最適輸送距離 OT(a, b, C): 0.400
print(f"行列P:\n{P.value.round(3)}")
行列P:
[[ 0.2 -0.  -0.  -0. ]
 [-0.   0.3  0.2 -0. ]
 [-0.  -0.   0.2 -0. ]
 [ 0.1  0.  -0.   0. ]]

点群の最適輸送距離の定式化#

  • 入力:

    • 比較する点群 \(\left\{x_1, \cdots, x_n\right\},\left\{y_1, \cdots, y_m\right\} \subset \mathcal{X}\)

    • 各点の距離を表す関数 \(C: \mathcal{X} \times \mathcal{X} \rightarrow \mathbb{R}\)

  • 出力:点群の距離

  • 最適輸送距離を以下の最適化問題の最適値と定義する

\[\begin{split} \begin{align} \underset{P \in \mathbb{R}^{n \times m}}{\operatorname{minimize}} & \sum_{i=1}^n \sum_{j=1}^m C\left(x_i, y_j\right) P_{i j} \quad(総コスト)\\ \text { s.t. } & P_{i j} \geq 0 \quad \forall i, j \quad(輸送量は非負)\\ & \sum_{j=1}^m P_{i j}=\frac{1}{n} \quad \forall i \quad(余りなし)\\ & \sum^n_{i=1} P_{i j}=\frac{1}{m} \quad \forall j \quad(不足なし) \end{align} \end{split}\]

ここで決定変数\(P_{ij}\)は点\(i\)から点\(j\)に輸送する量を表す

連続最適輸送#

連続の場合は難しいので、サンプリングして点群にしたり双対にしたりする(最適輸送入門 - Speaker Deck

関連文献#