Help us understand the problem. What is going on with this article?

GAN : DCGAN Part3 - Understanding Wasserstein GAN

目標

Microsoft Cognitive Toolkit (CNTK) を用いた DCGAN の続きです。

Part3 では、Part1 で準備した画像データを用いて CNTK による DCGAN の訓練を行います。
CNTK と NVIDIA GPU CUDA がインストールされていることを前提としています。

導入

GAN : DCGAN Part2 - Training DCGAN model では、Deep Convolutional Generative Adversarial Network (DCGAN) による顔生成モデルを扱ったので、Part3 では Wasserstein GAN を作成して訓練します。

Wasserstein GAN

今回実装した Wasserstein GAN [1] のネットワーク構造は Part2 と同じです。

ただし、Wasserstein GAN では Discriminator の出力層で sigmoid 関数を使用しないため Discriminator は Critic と呼ばれます。また今回の実装では、Critic の Batch Normalization をすべて Layer Normalization [2] に置き換えています。

オリジナルの GAN [3](Vanilla GAN (VGAN) と呼ぶことにします)と Wasserstein GAN (WGAN) の違いについては後ほど説明します。

訓練における諸設定

転置畳み込み・畳み込み層のパラメータの初期値は分散 0.02 の正規分布 [4] に設定しました。

今回実装した損失関数を下式に示します。[5]

\max_{C} \mathbb{E}_{x \sim p_{r}(x)}[C(x)] - \mathbb{E}_{z \sim p_z(z)}[C(G(z))] + \lambda \mathbb{E}_{x' \sim p_{x'}(x')}(||\nabla_{x'} C(x')||_2 - 1)^2 \\
\min_{G} -\mathbb{E}_{z \sim p_z(z)}[C(G(z))] \\
x' = \epsilon x + (1 - \epsilon) G(z), \epsilon \sim U[0, 1]

ここで、$C$, $G$ はそれぞれ Critic と Generator を表し、$x$ は入力画像、$z$ は潜在変数、$p_r$ は本物の画像データの分布、$p_z$ は偽物の画像データを生成する事前分布、$U$ は一様分布を表しています。今回は gradient penalty [5] を用いた Wasserstein GAN を実装し、$\lambda$ は 10 に設定しました。

Generator, Discriminator ともに最適化アルゴリズムは Adam [6] を採用しました。学習率は 1e-4、Adam のハイパーパラメータ $β_1$ は 0.0、$β_2$ は CNTK のデフォルト値に設定しました。

モデルの訓練はミニバッチサイズ 16 のミニバッチ学習によって 50,000 Iteration を実行しました。

実装

実行環境

ハードウェア

・CPU Intel(R) Core(TM) i7-6700K 4.00GHz
・GPU NVIDIA GeForce GTX 1060 6GB

ソフトウェア

・Windows 10 Pro 1909
・CUDA 10.0
・cuDNN 7.6
・Python 3.6.6
・cntk-gpu 2.7
・opencv-contrib-python 4.1.1.26
・numpy 1.17.3
・pandas 0.25.0

実行するプログラム

訓練用のプログラムは GitHub で公開しています。

wgan_training.py

解説

ところどころ証明や厳密性は欠けていますが、GAN の数理について理解を深めたいと思います。

そのために、まず確率分布の尺度である Kullback-Leibler divergence と Jensen-Shannon divergence から始めます。

Kullback-Leibler divergence and Jensen-Shannon divergence

2つの確率分布 $P(x), Q(x)$ の尺度として、Kullback-Leibler divergence が挙げられます。$H$ はエントロピーを表します。

\begin{align}
D_{KL} (P || Q) &= \sum_x P(x) \log \frac{P(x)}{Q(x)}  \\
&= H(P, Q) - H(P)
\end{align}

ただし KL divergence には対称性がない、つまり $D_{KL} (P || Q) \neq D_{KL} (Q || P)$ です。

また、エントロピー $H$ は以下の式で表されます。

H(P, Q) = \mathbb{E}_{x \sim P(x)} [- \log Q(x)] \\
H(P) = \mathbb{E}_{x \sim P(x)} [- \log P(x)]

これを用いて表記を少し変形すると以下のように表せます。

\begin{align}
D_{KL} (P || Q) &= H(P, Q) - H(P) \\
&= \mathbb{E}_{x \sim P(x)} [- \log Q(x)] - \mathbb{E}_{x \sim P(x)} [- \log P(x)] \\
&= \mathbb{E}_{x \sim P(x)} [- \log Q(x) - (- \log P(x))] \\
&= \mathbb{E}_{x \sim P(x)} [\log P(x) - \log Q(x)] \\
\end{align}

一方、KL divergence の派生版である Jensen-Shannon divergence は以下のように定義されます。

D_{JS} (P || Q) = \frac{1}{2} D_{KL} (P || M) + \frac{1}{2} D_{KL} (Q || M) \\
M = \frac{P + Q}{2}

JS divergence は対称性をもち、$0 \leq D_{JS} \leq 1$ となります。したがって、JS divergence が大きいと 2つの分布が似ておらず、逆に JS divergence が小さいと 2つの分布が似ていることになります。

Vanilla GAN

GAN を含む生成モデルは、現実に観測されるデータは何らかの生成モデルをもつという仮説に基づき、そのような生成モデルを獲得することを目指します。

まず、$D, G$ を Discriminator, Generator とし、$p_r$ は本物のデータの分布、$p_z$ は偽物のデータを生成する事前分布とします。また、Discriminator, Generator の評価関数を $V_D, V_G$ とします。

ここで、Discriminator は本物のデータと偽物のデータを識別する問題と考えると、Discriminator の評価関数は以下の式で表せます。

V_D = \mathbb{E}_{x \sim p_r(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log (1 - D(G(z)))]

さらに、Discriminator と Generator の評価関数の利得の和が 0 になる zero-sum game を導入すると、Generator の評価関数は次式のように定義するのが自然です。

V_G = - V_D

ここで、zero-sum game のナッシュ均衡($V_D$ に関して局所最小かつ $V_G$ に関して局所最小となる解)は minimax 解となることが知られているので、VGAN の損失関数が定義されます。

\min_G \max_D V(G, D) = \mathbb{E}_{x \sim p_r(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log (1 - D(G(z)))]

ここで、上式の最適化問題がユニークな解 $G^*$ をもち、この解が $p_g = p_r$ となることを示す必要があります。その際、以下の式を暗黙の了解として用います。

\mathbb{E}_{z \sim p_z(z)}[\log (1 - D(G(z)))] = \mathbb{E}_{x \sim p_g(x)}[\log (1 - D(x))]

VGAN の損失関数を連続関数で考えると、

\begin{align}
V(G, D) &= \int_x p_r(x) \log(D(x))dx + \int_z p_z(z) \log(1 - D(G(z)))dz \\
&= \int_x p_r(x) \log(D(x)) + p_g(x) \log(1 - D(x))dx
\end{align}

ここで、関数 $f(y) = a\log y + b\log(1 - y)$ の臨界点を考えると、微分して 0 になるので、

f'(y) = 0 \Rightarrow \frac{a}{y} - \frac{b}{1 - y} = 0 \Rightarrow y = \frac{a}{a + b}

また、$\frac{a}{a + b}$ における 2回微分を考えると、$a, b \in (0, 1)$ なので上に凸だと分かります

f'' \left( \frac{a}{a + b} \right) = - \frac{a}{(\frac{a}{a + b})^2} - \frac{b}{(1 - \frac{a}{a + b})^2} < 0

したがって、$\frac{a}{a + b}$ で最大となります。よって、

\begin{align}
V(G, D) &= \int_x p_r(x) \log(D(x))dx + \int_z p_z(z) \log(1 - D(G(z)))dz \\
&\leq \int_x \max_y p_r(x) \log(D(x)) + p_g(x) \log(1 - D(x))dx
\end{align}

となり、$D_{G} (x) = \frac{p_r}{p_r + p_g}$ のとき最大値であり、十分にユニークな解であることが分かりました。しかし実際に $D$ の最適解を求めることはできません。なぜなら真の本物のデータの分布 $p_r$ を知る術がないからです。ですが、$p_r$ は $G$ の最適解の存在を示すものであり、訓練において $D$ を近似することに専念すればよいことが分かりました。

次に Generator の最適解を考えるにあたって、GAN の最終到達目標は $p_g = p_r$ になることを再び提示しておきます。このとき $D_{G}^{*}$ は

D_{G}^{*} = \frac{p_r}{p_r + p_g} = \frac{1}{2}

となります。$D_{G}^{*}$ が得られたとき、Generator の最小化問題を考えると、

\begin{align}
\max_D V(G, D_{G}^{*}) &= \mathbb{E}_{x \sim p_r(x)}[\log D_{G}^{*}(x)] + \mathbb{E}_{z \sim p_z(z)}[\log (1 - D_{G}^{*}(G(z)))] \\
&= \mathbb{E}_{x \sim p_r(x)}[\log D_{G}^{*}(x)] + \mathbb{E}_{x \sim p_g(x)}[\log (1 - D_{G}^{*}(x))] \\
&= \mathbb{E}_{x \sim p_r(x)} \left[\log \frac{p_r}{p_r + p_g} \right] + \mathbb{E}_{x \sim p_g(x)} \left[\log \frac{p_g}{p_r + p_g} \right] \\
&= \mathbb{E}_{x \sim p_r(x)} [\log p_r - \log (p_r + p_g)] + \mathbb{E}_{x \sim p_g(x)} [\log p_g - \log (p_r + p_g)] \\
&= \mathbb{E}_{x \sim p_r(x)} \left[\log p_r - \log \left(\frac{p_r + p_g}{2} \cdot 2 \right) \right] + \mathbb{E}_{x \sim p_g(x)} \left[\log p_g - \log \left(\frac{p_r + p_g}{2} \cdot 2 \right) \right] \\
&= \mathbb{E}_{x \sim p_r(x)} \left[\log p_r - \log \left(\frac{p_r + p_g}{2} \right) - \log 2 \right] + \mathbb{E}_{x \sim p_g(x)} \left[\log p_g - \log \left(\frac{p_r + p_g}{2} \right) - \log 2 \right] \\
&= \mathbb{E}_{x \sim p_r(x)} \left[\log p_r - \log \left(\frac{p_r + p_g}{2} \right) \right] + \mathbb{E}_{x \sim p_r(x)} [- \log 2] + \mathbb{E}_{x \sim p_g(x)} \left[\log p_g - \log \left(\frac{p_r + p_g}{2} \right) \right] + \mathbb{E}_{x \sim p_g(x)} [ - \log 2] \\
\end{align}

ここで、KL divergence を用いると、

\begin{align}
\max_D V(G, D_{G}^{*}) = D_{KL} \left(p_r \middle| \middle| \frac{p_r + p_g}{2} \right) + D_{KL} \left(p_g \middle| \middle| \frac{p_r + p_g}{2} \right) - \log 4
\end{align}

さらに、$\frac{p_r + p_g}{2} = M$ とすると、JS divergence より、

\max_D V(G, D_{G}^{*}) = 2 \cdot D_{JS} (p_r || p_g) - \log 4

したがって、$p_g = p_r$ のとき、大域的最小値の候補として $- \log 4$ をもつことが分かり、同時に Generator の最小化問題は JS divergence を最小化していると考えることができます。

以上のような理論的背景から、十分な表現力と本物のデータがあれば生成モデルを学習することができますが、依然として GAN の訓練は難しいです。

Wasserstein GAN

GAN を解析した論文 [7] では VGAN の訓練が困難な理由が明らかにされており、同論文の著者らがその解決策として提案したのが WGAN [1] です。

VGAN では Generator は結果的に JS divergence を尺度としていましたが、WGAN では Wasserstein distance を尺度とします。

Wasserstein distance は Earth-Mover(EM) distance とも呼ばれ、輸送最適化問題に基づいた尺度で、ここではある確率分布を別の確率分布に近づけるためのコストを表します。

W(p_r, p_g) = \inf_{\gamma \in \Pi(p_r, p_g)} \mathbb{E}_{(x, y) \sim \gamma} [||x - y||]

Wasserstein distance は KL divergence や JS divergence にはない有益な性質があります。2つの確率分布に重なりがない場合、KL divergence は発散してしまい、JS divergence は $\log 2$ となって微分不可能になってしまいますが、Wasserstein distance は滑らかな値をとるため、勾配法での最適化が安定します。

そして、Wasserstein distance は Kantorovich-Rubinstein 双対性に基づいた変換により、下の式のような損失関数が得られます。ここで、$C, G$ は Critic, Generator を表します。

\min_G \max_{||C||_L \leq K} \mathbb{E}_{x \sim p_r(x)}[C(x)] - \mathbb{E}_{z \sim p_z(z)}[C(G(z))]

ただし、Wasserstein distance には Lipshitz 連続性と呼ばれる制約が課されるため、これを保証するための方法として重みパラメータを clip する方法を用いています。

しかし weight clipping は強引な方法であるため訓練に失敗する場合があり、その改善策として gradient penalty [5] が提案されました。

Gradient penalty では、最適化された Critic において、本物のデータと生成データの間の任意の点に対する勾配の L2ノルムが 1 になるという事実を利用し、損失関数に勾配の L2ノルムが 1 以外のときにペナルティを課す方法です。

\lambda \mathbb{E}_{x' \sim p_{x'}(x')}(||\nabla_{x'} C(x')||_2 - 1)^2

本物のデータと生成データの間の任意の点 $x'$ は、本物の画像データと Generator が生成した画像データをランダムな割合でブレンドした画像で表現します。

x' = \epsilon x + (1 - \epsilon) G(z), \epsilon \sim U[0, 1]

また、weight clipping と gradient penalty は制約が強すぎて Generator の表現力が低くなってしまうので、本物のデータの近傍における勾配の L2ノルムを 1 に近づける正則化項を用いる DRAGAN [8] も提案されています。

結果

訓練時の各損失関数を可視化したものが下図です。横軸は繰り返し回数、縦軸は損失関数の値を表しています。Critic, Generator ともに値が非常に大きくなっています。

wgan_logging.png

訓練した Generator で生成した顔画像を下図に示します。損失関数が大きな値になっているにもかかわらず、失敗している画像もありますが、Part2 よりも見栄えの良い顔画像を生成しているように見えます。

wgan_image.png

訓練時の画像生成の変遷をアニメーションで示したのが下図です。

wgan.gif

Part2 と同様に Inception-v3 [9] で Inception Score [10] を測ってみると以下のような結果になりました。

Inception Score 2.14

参考

CNTK 206 Part C: Wasserstein and Loss Sensitive GAN with CIFAR Data

GAN : DCGAN Part1 - Scraping Web images
GAN : DCGAN Part2 - Training DCGAN model

  1. Martin Arjovsky, Soumith Chintala, and Leon Bottou, "Wasserstein GAN", arXiv preprint arXiv:1701.07875 (2017).
  2. Jimmy Lei Ba, Jamie Ryan Kiros, and Geoffrey E. Hinton. "Layer Normalization", arXiv preprint arXiv:1607.06450 (2016).
  3. Ian J. Goodfellow, Jean Pouget-Abadie, Mehdi Mira, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, and Yoshua Bengio. "Generative Adversarial Nets", Advances in neural information processing systems. 2014, pp 2672-2680.
  4. Alec Radford, Luke Metz, and Soumith Chintal. "Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks", arXiv preprint arXiv:1511.06434 (2015).
  5. Ishaan Gulrahani, Faruk Ahmed, Martin Arjovsky, Vincent DuMoulin, and Aron Courville. "Improved Training of Wasserstein GANs", Neural Information Processing Systems (NIPS). 2017, pp 5767-5777.
  6. Diederik P. Kingma and Jimmy Lei Ba. "Adam: A method for stochastic optimization", arXiv preprint arXiv:1412.6980 (2014).
  7. Martin Arjovsky and Leon Bottou, "Towards Princiled Methods for Training Generative Adversarial Networks", International Conference on Learning Representations (ICLR). 2017.
  8. Naveen Kodali, Hacob Avernethy, James Hays, and Zsolt Kira, "On Covergence and Stability of GANs", arXiv preprint arXiv:1705.07215.
  9. Christian Szegedy, Vincent Vanhoucke, Sergey Ioffe, Jon Shlens, and Zbigniew Wojna. "Rethinking the Inception Architecture for Computer Vision", The IEEE Conference on Computer Vision and Pattern Recognition (CVPR). 2016, pp 2818-2826.
  10. Tim Salimans, Ian Goodfellow, Wojciech Zaremba, Vicki Cheung, Alec Radford, and Xi Chen, "Improved Techniques for Training GANs", Neural Information Processing Systems. 2016. pp 2234-2242.
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
No comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
ユーザーは見つかりませんでした