LoginSignup
6
2

More than 5 years have passed since last update.

WAE-MMDをChainerで実装

Last updated at Posted at 2018-03-27

Wasserstein Auto-Encoders (MMD)

TensorFlow実装を参考に、ChainerでWAE-MMDを実装してみました。

動機

以前作ったモデルをWebDNNを用いてブラウザで動作するデモを作ろうと思ったのですが、一部のオペレータがWebDNNでは対応していないようで処理できませんでした。ならば理解を深めるためにもChainerで実装しよう、ということでやってみてある程度結果を出すことができました。
出来上がったWebDNNによるデモは以下で見られます。

構造

基本的にはDCGANをベースとしたネットワーク構造です。conv/deconv + batch normを3層持たせています。

Encoderの事前訓練

Encoderの入力画像に対して出力される潜在空間の分布を平均0、分散1となるような分布に近づけるよう事前訓練をします。バッチサイズ分の乱数を生成して、Encoderの出力の平均と分散の二乗平均誤差を最小化します。

目的関数

論文では以下のように定義されています。

D_{WAE}(P_X, P_G) = \inf_{Q(Z|X)\in \mathcal{Q}} \mathbb{E}_{P_X} \mathbb{E}_{Q(Z|X)}[c(X, G(Z))] + \lambda \cdot \mathcal{D}_Z(Q_Z,P_Z)

最初の項はautoencoder($\mathcal{Q}$)の復元誤差で、後ろの項は2つの分布$Q_Z$(エンコーダの潜在変数)と$P_Z$(サンプリングで得る値)のダイバージェンスです。VAEであればKLダイバージェンスを使う部分ですが、ここをGAN, MMDによって求めるところがWAEの特徴です。

復元誤差

$||x - y||_2^2$を使って求めています。TensorFlow実装では、L1, L2ノルムも選択できるようになっています。

ダイバージェンス

TensorFlowでは
* GAN
* MMD - RBFカーネル
* MMD - IMQカーネル
が選択できるようになっていますが、今の実装ではIMQカーネルを用いたMMDのみを実装しています。

MMD (Maximum Mean Discrepancy)

MMDに関するわかりやすい解説がQiita内にありました。

SVMでよく使われるカーネル法を、確率分布の測度として拡張したものがKernel mean embeddingと呼ばれるもので、2つの確率分布を核再生ヒルベルト空間(RKHS, Reproducing Kernel Hirbert Space)へ写像します。
写像された特徴空間内の2つの確率分布の距離がMMD、だそうです。元々は2標本検定に用いるノンパラメトリックな手法のようです。

論文では以下のように定義しています。

MMD_k(P_Z, Q_Z) = ||\int_Z k(z \cdot)dP_Z(z) - \int_Z k(z \cdot)dQ_Z(z)||_{\mathcal{H}_k}

検定で用いる場合にはRKHS上でのノルムの二乗を用いるようですが、ここでは二乗しない値を使うようです。

\mathcal{D}_Z(Q_Z,P_Z) = MMD_k(P_Z, Q_Z)

として、係数$\lambda$をかけて計算をします。$\lambda$は大きな数(TF実装では100)から始めて、徐々に減衰させます(自分の実装では未実装)。
MMDはSGDによって求められると論文にあります。尺度としては2つの確率分布の距離であり、それを最小化することでオートエンコーダーから生成される潜在変数の確率分布$P_Z$を、目的の分布$Q_Z$に近づけることができるのだと理解しました。
ノンパラメトリックな手法なので、任意分布の形を$Q_Z$(実装中では正規分布)で近似できる、ということなのだろうと理解しています。

inverse multiquadric kernel (IMQ)

自分の理解がどんどん怪しくなってくるのですが、IMQに至ってはいまだによく理解できていません。もうちょっと以下の資料をしっかりと読んで理解したいところです。

オリジナルの実装ではガウシアンカーネルを用いた手法も選べるようになっていました。

出力例

iter 28396:
image.png

今後の展望

  • λの減衰を実装する
  • ~WebDNNデモを公開する~ (done)
  • WAE-GANを実装する
  • MMD-RBF(ガウシアンカーネル)を実装する
  • ほかのネットワーク構造、パラメータに対応する
  • IMQを理解する
  • 解説に間違いがあったら直す
6
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
2