WGAN などで使われるWasserstein距離は確率分布の間の距離のひとつです。本稿では、離散型確率分布に対するWasserstein距離の定義と、Pythonによる計算の例を示します。
Wasserstein距離の定義
ふたつのカテゴリカル分布$a$、$b$それぞれの次元を$m$、$n$とします。このふたつの分布の各次元の間で、確率の移動のコスト行列 $C \in \mathbb{R}^{n \times m}$ が与えられたとき、
ふたつの分布を周辺分布にもつような結合確率分布の空間
U(a,b) := \{P \in \mathbb{R}^{n \times m} : P \mathbb{1}_m = a \ {\rm and} \ P^T \mathbb{1}_n = b\}
の中で、最小の総移動コストを求める問題
L_C(a,b) := \min_{P\in U(a,b)} \langle C, P\rangle
をKantorovichの最適輸送問題と呼びます。ここで、 $\mathbb{1}_m$ は、$m$次元のすべての要素が$1$のベクトル、$ \langle C, P\rangle:= \sum_{i,j} C_{i,j} P_{i,j} $ です。
コスト行列$C$が、ある距離行列$D$、パラメーター$p>1$に対して、$C = D^p$で与えられたとき、分布$a$、$b$間のp-Wassserstein距離$W_p(a,b)$は、次にように定義されます。
W_p(a,b) := L_{D^p}(a,b)^{1/p}
特に、$p=1$のとき、Earth movers距離と呼びます。
Wasserstein距離の計算
Wasserstein距離の計算、すなわちKantrovichの最適輸送問題は次のように線形計画問題として解くことができます。
まず、$p_i \in R^n$、$c_i \in R^n$で
P = \left(\begin{array}{cccc} p_1 & p_2 & \cdots & p_m \end{array}\right)
C = \left(\begin{array}{cccc} c_1 & c_2 & \cdots & c_m \end{array}\right)
と表すと、コスト関数は、
L_C(a,b) := \min_{P\in U(a,b)} \left(\begin{array}{cccc}{}^tc_1 &{}^tc_2 &\cdots & {}^tc_m \end{array}\right) \left(\begin{array}{c}p_1 \\ p_2 \\ \vdots \\ p_m\end{array}\right)
と書き直せます。また、$\phi_{i} \in \mathbb{R}^n$を$i$番目の要素が$1$、それ以外の要素が$0$のベクトルとすると、結合確率についての制約は
\begin{array}{c}
\left(\begin{array}{ccc}
\phi_{1} & \cdots & \phi_1 \\
\phi_2 & \cdots & \phi_2 \\
\vdots & & \vdots \\
\phi_n & \cdots & \phi_n \\
\end{array}\right)\\
\left(\begin{array}{ccc}
\mathbb{1}_n & 0 & \cdots & 0\\
0 & \mathbb{1}_n & \cdots & 0\\
\vdots & & & \vdots \\
0 & \cdots & 0 & \mathbb{1}_n \\
\end{array}\right)
\end{array}
\left(\begin{array}{c} p_1 \\ p_2 \\ \vdots \\ p_m\end{array}\right)
=\left(\begin{array}{c}a \\ b\end{array}\right)
と書き直せます。
これで線形な制約下で線形なコスト関数を最小化する $p_i$ を求める線形計画問題として解くことができます。以下でPythonを使った計算例を示します。
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
確率分布 $a$、$b$
n = 10
np.random.seed(0)
a = np.abs(np.random.randn(n))
a = a / np.sum(a)
b = np.abs(np.random.randn(n))
b = b / np.sum(b)
plt.figure()
plt.bar(np.arange(n),a, width=1.0)
plt.figure()
plt.bar(np.arange(n),b, width=1.0)
結合確率についての制約式
cond_b = np.concatenate((a,b),axis=0)
cond_a = np.zeros((2*n,n*n))
for i in range(n):
ind_vector_a = np.zeros(n)
ind_vector_a[i] = 1
cond_a[i,:] = np.tile(ind_vector_a, n)
ind_vector_b = np.ones(n)
cond_a[n+i,(n*i):(n*(i+1))] = ind_vector_b
距離行列 $D$、コスト行列$C$
from matplotlib import cm
d_mat = np.zeros([n,n])
for i in range(n):
d_mat[i,(i+1):n] = np.arange(1,n-(i))
i_lower = np.tril_indices(n, -1)
d_mat[i_lower] = d_mat.T[i_lower]
p = 1
c_mat = d_mat**p
c_mat_reshape = np.squeeze(np.asarray(c_mat.reshape((1, n*n))))
plt.imshow(c_mat, cmap=cm.gist_heat)
plt.axis('off')
plt.colorbar()
線形計画問題を解く
from scipy.optimize import linprog
sol = linprog(c_mat_reshape, A_eq=cond_a, b_eq=cond_b)
1-Wasserstein距離は
sol.fun # -> 1.3601368298912992!
結合確率分布は
from matplotlib import cm
p_mat = np.reshape(sol.x, (n,n))
plt.imshow(p_mat, cmap=cm.gist_heat)
plt.axis('off')
plt.colorbar()
結合確率分布の周辺分布がもとの確率分布$a$、$b$と一致していることの確認
p_mat_margin_col = np.sum(p_mat, axis=1)
p_mat_margin_row = np.sum(p_mat, axis=0)
plt.figure()
plt.bar(np.arange(n),a, width=1.0, label='$a$')
plt.bar(np.arange(n), p_mat_margin_row, label='$P\mathbb{1}$')
plt.legend()
plt.figure()
plt.bar(np.arange(n),b, width=1.0, label='$b$')
plt.bar(np.arange(n), p_mat_margin_col, label='$P^T \mathbb{1}$')
plt.legend()