36
23

More than 1 year has passed since last update.

【機械学習】最適輸送とPOTライブラリについて

Last updated at Posted at 2021-12-01

はじめに

最適輸送という概念を学んだので、自分用の理解メモのため簡単にまとめます。

最適輸送について

確率分布を比較する手法として使える。確率分布の山を移動させて一致するのにかかるコストというイメージ。下の場合だと確率分布μ0から確率分布μ1に山を移す際、砂山をイメージしてざざっと動かす感じ。μ0の任意の箇所からμ1の任意の箇所に移せるが、最も労力のかからない動かし方が最適輸送になる。

400px-Continuous_optimal_transport.png
Transportation theory (mathematics), Wikipediaより引用

離散分布間の最適輸送距離の定式化

入力

  • 比較する離散分布
\mu_0 = \{a_1,a_2,...,a_m\},  \mu_1 =\{b_1,b_2,...,b_n\}

  • 各点の距離コスト(metric cost)を表す行列
M \in \mathbb{R}^{m \times n}_{+}

出力:最適輸送距離

最適輸送距離を次の最適化問題の最適値$\gamma^*$と定義する。

ここで、$\gamma_{ij}$はiからjへの輸送コストを示す。

\gamma^* = \underset{\gamma\in\mathbb{R}^{n \times m}_+}{arg\mathrm{min}}\sum_{i,j}^n\gamma_{i,j}M_{i,j}\\


\mathrm{s.t.} \quad M_{ij} \geq 0,\forall i,j, \quad \sum_{j=1}^m\gamma_{ij}=\mu_0,\forall i,\quad\sum_{i=1}^n\gamma_{ij}=\mu_1,\forall j

なお、KLダイバージェンスも2つの確率分布の比較で用いられる。PとQの2つの確率分布があった場合、次が離散分布時のKLダイバージェンスの定義。

D_{KL}(p||q) = \sum_i{P(i)\log\frac{P(i)}{Q(i)}}

KLダイバージェンスはpythonならscipy.special.kl_divで計算可能。

しかし、KLダイバージェンスは各要素ごとに独立に項を足し合わせている関係で、適切な比較にならない場合がある。最適輸送の考え方だと要素同士の距離のうまく定義することで、他の要素も考慮した比較になるため、より良い比較にできる場合がある。ただし、最適輸送はKLよりも断然計算コストが高くなる点には注意する

なお、離散分布の場合は線形計画となり、既存のPOT等のライブラリで解くことが可能。連続分布の場合は割愛したが直接解くのが難しく、該当の連続分布からサンプリングをして点群比較に帰着させて解くなどする。

POTライブラリについて

最適輸送問題はPOTライブラリで解くことができる。インストールはPOT: Python Optimal Transportの通り、pipでインストール可能。

POTは先の離散分布時の線形計画最適化問題のソルバが実装されており、ot.emd(a,b,M)を呼び出すと、最適輸送行列γ*が返ってくる。aとbは質量ヒストグラムでMは距離行列。なお、Mは2つの分布を入力とすることでPOTライブラリでも計算可能。
quick startの2D Optimal transport between empirical distributionsを写経して試してみる。

import numpy as np
import matplotlib.pylab as pl
import ot
import ot.plot

n = 100 # 各分布のサンプル数

#1つ目の分布パラメータ
mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])

#2つ目の分布パラメータ
mu_t = np.array([4, 4])
cov_t = np.array([[1, -.8], [-.8, 1]])

#2Dガウス分布で作成それぞれの分布を作成
xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s)
xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t)
#各点の重さ。今回は全て1/nとしている
a, b = np.ones((n,)) / n, np.ones((n,)) / n 

# 距離を定義する
# ot.distで、xsとxtの距離行列を計算する。
# https://pythonot.github.io/all.html?highlight=dist#ot.dist
M = ot.dist(xs, xt) 
M /= M.max()

#作成した2つの分布を可視化
pl.figure(1)
pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.legend(loc=0)
pl.title('Source and target distributions')

7afbf0f9-ec59-4cb0-bdb3-bcc7f2a34ec7.png

#計算した距離行列の可視化
pl.figure(2)
pl.imshow(M, interpolation='nearest') #2次元配列を画像として表示
pl.title('Cost matrix M')

9680dd2e-ad7b-47a3-bb47-8d65b67ae672.png

#最適輸送行列の計算
G0 = ot.emd(a, b, M)

#最適輸送行列の可視化。前と後の対応関係がわかる。
pl.figure(3)
pl.imshow(G0, interpolation='nearest') #2次元配列を画像として表示
pl.title('OT matrix G0')

00749660-9656-4f82-b316-b825c8e958a7.png

#2つの分布図に対しての輸送対応を可視化
pl.figure(4)
ot.plot.plot2D_samples_mat(xs, xt, G0, c=[.6, .6, .6]) #見やすさのため、サンプルと線の色を変えました。
pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.legend(loc=0)
pl.title('OT matrix with samples')

b7158067-9e01-465e-ad2f-13ad14b3925e.png

ひとまず、2つの分布の最適輸送についてはこのサンプルコードを参考に適宜入れ替えするだけで実施できそう。

最後に

最適輸送の計算アルゴリズムの研究動向によると、色変換や形状変換といった画像処理系の応用、言語処理系、データ生成やデータ補完などでも応用があるなど、道具としても面白そうかなと。

参考文献、関連文献

36
23
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
36
23