LoginSignup
37
34

More than 3 years have passed since last update.

Wasserstein距離

Posted at

WGAN などで使われるWasserstein距離は確率分布の間の距離のひとつです。本稿では、離散型確率分布に対するWasserstein距離の定義と、Pythonによる計算の例を示します。

Wasserstein距離の定義

ふたつのカテゴリカル分布$a$、$b$それぞれの次元を$m$、$n$とします。このふたつの分布の各次元の間で、確率の移動のコスト行列 $C \in \mathbb{R}^{n \times m}$ が与えられたとき、
ふたつの分布を周辺分布にもつような結合確率分布の空間

U(a,b) := \{P \in \mathbb{R}^{n \times m} : P \mathbb{1}_m = a \ {\rm and} \ P^T \mathbb{1}_n = b\}

の中で、最小の総移動コストを求める問題

L_C(a,b) := \min_{P\in U(a,b)} \langle C, P\rangle

をKantorovichの最適輸送問題と呼びます。ここで、 $\mathbb{1}_m$ は、$m$次元のすべての要素が$1$のベクトル、$ \langle C, P\rangle:= \sum_{i,j} C_{i,j} P_{i,j} $ です。

コスト行列$C$が、ある距離行列$D$、パラメーター$p>1$に対して、$C = D^p$で与えられたとき、分布$a$、$b$間のp-Wassserstein距離$W_p(a,b)$は、次にように定義されます。

W_p(a,b) := L_{D^p}(a,b)^{1/p}

特に、$p=1$のとき、Earth movers距離と呼びます。

Wasserstein距離の計算

Wasserstein距離の計算、すなわちKantrovichの最適輸送問題は次のように線形計画問題として解くことができます。

まず、$p_i \in R^n$、$c_i \in R^n$で

P = \left(\begin{array}{cccc} p_1 & p_2 & \cdots & p_m \end{array}\right)
C = \left(\begin{array}{cccc} c_1 & c_2 & \cdots & c_m \end{array}\right)

と表すと、コスト関数は、

L_C(a,b) := \min_{P\in U(a,b)} \left(\begin{array}{cccc}{}^tc_1 &{}^tc_2 &\cdots & {}^tc_m \end{array}\right)  \left(\begin{array}{c}p_1 \\ p_2 \\ \vdots \\ p_m\end{array}\right)

と書き直せます。また、$\phi_{i} \in \mathbb{R}^n$を$i$番目の要素が$1$、それ以外の要素が$0$のベクトルとすると、結合確率についての制約は

\begin{array}{c}
\left(\begin{array}{ccc}
\phi_{1} & \cdots  & \phi_1 \\
\phi_2 & \cdots  & \phi_2 \\
\vdots & & \vdots \\
\phi_n & \cdots  & \phi_n \\
\end{array}\right)\\
\left(\begin{array}{ccc}
\mathbb{1}_n & 0 & \cdots & 0\\
0 & \mathbb{1}_n  & \cdots & 0\\
\vdots & & & \vdots \\
0  & \cdots & 0 & \mathbb{1}_n \\
\end{array}\right)
\end{array} 
\left(\begin{array}{c} p_1 \\ p_2 \\ \vdots \\ p_m\end{array}\right) 
=\left(\begin{array}{c}a \\ b\end{array}\right)

と書き直せます。
これで線形な制約下で線形なコスト関数を最小化する $p_i$ を求める線形計画問題として解くことができます。以下でPythonを使った計算例を示します。

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

確率分布 $a$、$b$

n = 10
np.random.seed(0)
a = np.abs(np.random.randn(n))
a = a / np.sum(a)
b = np.abs(np.random.randn(n))
b = b / np.sum(b)

plt.figure()
plt.bar(np.arange(n),a, width=1.0)
plt.figure()
plt.bar(np.arange(n),b, width=1.0)

a.png
b.png

結合確率についての制約式

cond_b = np.concatenate((a,b),axis=0)
cond_a = np.zeros((2*n,n*n))
for i in range(n):
    ind_vector_a = np.zeros(n)
    ind_vector_a[i] = 1
    cond_a[i,:] = np.tile(ind_vector_a, n)
    ind_vector_b = np.ones(n)
    cond_a[n+i,(n*i):(n*(i+1))] =  ind_vector_b

距離行列 $D$、コスト行列$C$

from matplotlib import cm
d_mat = np.zeros([n,n])
for i in range(n):
    d_mat[i,(i+1):n] = np.arange(1,n-(i))
i_lower = np.tril_indices(n, -1)
d_mat[i_lower] = d_mat.T[i_lower]
p = 1
c_mat = d_mat**p
c_mat_reshape = np.squeeze(np.asarray(c_mat.reshape((1, n*n))))

plt.imshow(c_mat, cmap=cm.gist_heat)
plt.axis('off')
plt.colorbar()

c.png

線形計画問題を解く

from scipy.optimize import linprog
sol = linprog(c_mat_reshape, A_eq=cond_a, b_eq=cond_b)

1-Wasserstein距離は

sol.fun # -> 1.3601368298912992!

結合確率分布は

from matplotlib import cm
p_mat = np.reshape(sol.x, (n,n))
plt.imshow(p_mat, cmap=cm.gist_heat)
plt.axis('off')
plt.colorbar()

p.png

結合確率分布の周辺分布がもとの確率分布$a$、$b$と一致していることの確認

p_mat_margin_col = np.sum(p_mat, axis=1)
p_mat_margin_row = np.sum(p_mat, axis=0)


plt.figure()
plt.bar(np.arange(n),a, width=1.0, label='$a$')
plt.bar(np.arange(n), p_mat_margin_row, label='$P\mathbb{1}$')
plt.legend()
plt.figure()
plt.bar(np.arange(n),b, width=1.0, label='$b$')
plt.bar(np.arange(n), p_mat_margin_col, label='$P^T \mathbb{1}$')
plt.legend()

a_result.png
b_result.png

参考文献

37
34
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
37
34