はじめに
Wasserstein Generative Adversarial Networks
という論文を読んだのでまとめていきます。なお、今回は個人的な事情から2章までのみをまとめます。
1. 対象とするトピック
1.1 キーワード
Wasserstein distance
, GAN
arxiv
pmlr
1.2 Abstract
We introduce a new algorithm named WGAN, an alternative to traditional GAN training. In this new model, we show that we can improve the stability of learning, get rid of problems like mode collapse, and provide meaningful learning
curves useful for debugging and hyperparameter searches. Furthermore, we show that the corresponding optimization problem is sound, and provide extensive theoretical work highlighting the deep connections to different distances between distributions.
2. Introduction
この章では確率分布の教師なし学習を行う上での手法の解説と、実用において望ましい性質が示されています。
確率分布の教師なし学習を行う際には、以下のような問題を考えることが多いです(This is often ~ 以降)。
上記の最大化問題は、下記の真の分布とパラメータが定めるモデル分布との間の Kullback-Leibler divergence
の最小化問題と等しいです。
この問題を解けば真の分布に近いモデル分布を定めるために必要なパラメータが手に入ります。しかし、一般的にはモデル分布と真の分布との間で Kullback-Leibler divergence
を定義することができません。この Kullback-Leibler divergence
を定義するために、モデル分布にノイズを加えることが多いです。すると、モデル分布と真の分布で重なりが生じ、 Kullback-Leibler divergence
を定義することができます。しかし、ノイズのせいでモデルの質が下がるという別の問題が生じます。
そこで、更に別の手法を考える必要があります。今度はサンプルを生成するアプローチを考えます。具体的には以下の通りです。
上記の通り、ある分布に従うランダム変数からモデル分布に従うようなサンプルを生成します。その後モデルのパラメータを変化させながらモデル分布を真の分布に近づけます。この手法はVAE
やGAN
で使用されています。GAN
はVAE
よりも柔軟性で優れています。しかし、GAN
は学習が不安定で扱いが難しいという欠点を抱えています。そこでこの論文では、上述した Kullback-Leibler divergence
を含む様々な距離関数について考察を行い、GANの抱える問題を解消しようと試みています。具体的には、現在GAN
で用いられているJensen-Shannon divergence
ではなくWasserstein divergence
が距離関数として適切であると示しています。そして、Wasserstein divergence
を取り入れたWGAN
について実験等を行なっています。
なお、論文では距離関数の収束性と連続性に注目しています。そして、これらが優れている関数が距離関数として適切であるとしています。収束性についての記述は以下の通りです。
簡単な説明は以下の通りです。分布の学習において必要なのは損失関数を勾配法等で連続的に最適化することです。損失関数の収束性と連続性が満たされていない場合、勾配法によって損失関数を最適化することが難しくなってしまいます。そして、損失関数は空間における分布間の距離計算方法に依存しています。そのため、距離関数が収束性と連続性を満たす必要があります。
GANの学習が難しい理由
GAN
ではJS-divergence
を近似するD
とD
の生成した損失を最小化するように学習するG
が登場します。最終的にはG
が生成するモデル分布が真の分布に近づくことがゴールです。そのためには、D
が生成する損失を勾配法で連続的に最適化できなければいけません。しかし、D
の学習が進むにつれて損失関数の勾配は0に近づいていきます。そのため、D
を十分に学習したのちにG
を学習させるという手段が取れません。一方で学習が不足しているD
が生成する損失関数を用いても、近似が不十分な損失関数をG
が最小化することになります。この場合、G
は誤った損失を最小化しているので学習結果は芳しくありません。以上のトレードオフがJS-divergence
を距離函数とした場合のGAN
の学習が難しい理由です。
3. Different Distances
この章ではまず、様々な確率分布間の距離指標を列挙します。その後、EM distance
がその他の距離よりも優れている点を示します。(なお論文ではEM distance
がGANで必要となる低次元多様体でサポートされた確率分布を学習するために適していると結論付けています。そして、以降の章ではEM distance
を用いたWGAN
について考察しています。)
まず、記号の定義は以下の通りです。
この時、距離指標が以下のように定義されます。
次に具体例を通じて各距離指標の性質を確認します。最初に考えるのは一様分布の学習です。以下のように記号を定義します。
すると、以下のように距離が計算されます。
計算結果から明らかなように、$\theta \rightarrow \theta_0$へと収束した時に収束する距離はEM distance
のみです。
上記の具体例からEM distance
を距離として用いた場合のみ損失関数は連続になります。つまり、EM distance
を距離として用いた損失関数のみが勾配法によって最適化を行うことができます。よって低次元多様体上での確率分布の学習の際にEM distance
を距離として用いることが適切です。また、同様の結論がより一般化された状況でも得られます。
参考文献
- Martin Arjovsky, Soumith Chintala, and L´eon Bottou (2017). Wasserstein Generative Adversarial Networks, http://proceedings.mlr.press/v70/arjovsky17a.html
- Deep Learning JP (2017), [DL輪読会]Wasserstein GAN/Towards Principled Methods for Training Generative Adversarial Networks