Edited at

GAN(Generative Adversarial Networks)を学習させる際の14のテクニック

More than 1 year has passed since last update.

※この記事は"How to Train a GAN?" at NIPS2016を、ここを参考にして、私なりに解釈して、重要そうな部分を引用して翻訳したものです。役に立つことが多かったので共有致します。

※GANの説明は省略します。G=Generator、D=Discriminatorとして説明します。

※それぞれのテクニックに根拠はあまり書いてないですが、ほとんどが論文に書いてあった「こうすればうまくいった」というものです。GANの学習がうまくいかないときに試してみると良いと思います。


    1. 入力を正規化 (Normalize the inputs)


    ・DのInput(=Gの出力)となる画像を[-1,1]の範囲に正規化する。

    ・Gの出力が[-1,1]となるように、OutputのところをTanhにする。

    2. GのLoss関数を修正する (A modified loss function)


    ・Gの最適化では、min(log 1-D)がLossとなるが、実際にはmax(logD)をLossにした方が良い。
    → 前者のLossでは勾配消失が生じやすい。Goodfellow et. al (2014)

    3. Zはガウス分布から (Use a spherical Z)


    ・GのInputとなるベクトルZは、一様分布ではなく、正規分布からサンプルする

    ・Interpolationする時は、点AからBの直線上の点ではなく、大きな円(Great Circle)上の点を用いて行う



    こちらのコードが参考になる

    4. BatchNormalization (Batch Norm)


    ・Real画像とFake画像(=Generated)で、別々のミニバッチを作る。= 一つのミニバッチに、RealとFakeが混ざらないようにする。

    batchmix.png


    ・Batchnormを使わないときは、Instance Normalizationを使う。(全てのサンプルから、平均を引いて、標準偏差で割る)

    5. ReLUやMaxPoolingのように、勾配がスパースになるものは避ける(Avoid Sparse Gradients: ReLU, MaxPool)


    ・スパースな勾配があると、GANの学習が安定しにくくなる。

    ・GとDの両方において、LeakyReLUが有効。

    ・ダウンサンプリングするときは、Average PoolingやStrideありのConv2dを使う。

    ・アップサンプリングするときは、PixelShuffleやStrideありのConvTranspose2dを使う。

    6. DのOutputの正解ラベル(Real/Fake)には、ノイジーなラベルを使う (Use Soft and Noisy Labels)


    ・普通はReal=1、Fake=0とするが、Real=0.7~1.2、Fake=0.0〜0.3からランダムにサンプルする。(Label Smoothing) Salimans et. al. 2016

    ・学習の際、たまにRealとFakeのラベルを入れかえる

    7. DCGANやハイブリッドモデル(DCGAN / Hybrid Models)


    ・DCGANが使える状況であれば使う。

    ・DCGANが使えないのであれば、VAE+GANやKL+GANのような、ハイブリッドモデルを使う。

    8. 強化学習における安定性のテクニックを使う (Use stability tricks from RL)


    ・過去のFake画像を保存しておいて、それを時々、DiscriminatorのInputとして用いたり、過去のGとDのモデルを保存しておき、数Iterationだけ、そのモデルを使って学習を行う。(Experience Replay)

    ・Deep Deterministic Policy Gradientsにおけるテクニックを使う。(詳細省略します)

    ・Pfau & Vinyals (2016)が参考になる。

    9. 最適化手法はAdamを使う (Use the ADAM Optimizer)


    ・現状だとAdamを使うのが一番良い。Radford et. al. 2015

    ・DiscriminatorにはSGD、GeneratorにはAdamを使うというのも良い。

    10. 学習がうまく進んでいるのかを確認する (Track failures early)


    ・DのLossが0になっていたら、うまくいっていない。

    ・勾配のノームが100を超えるくらい大きければ、うまくいっていない。

    ・うまくいっているときは、DのLossが小さな分散で、IterationごとにLossが小さくなっていく。大きい分散だとあまりうまくいっていない。

    ・GのLossが安定して小さくなっていっている場合、あまり意味のない画像を生成している可能性がある。

    11. DとGの学習を統計的に制御しない (Dont balance loss via statistics (unless you have a good reason to))


    ・Gの学習回数:Dの学習回数 の比率を見つけるのは、これまでに多くの人がやってきたけど、難しい。

    ・どうしても制御したいのであれば、例えば、「DのLossがAより大きければ、Dを学習」、「GのLossがBより大きければ、Gを学習」のように行う。

    12. 画像データにラベルがついているのであれば、使う (If you have labels, use them)


    ・例えば、顔画像であれば、性別や年齢のように、追加的なラベルある場合、Auxillary GANのように、Discriminatorがラベルの分類も行うようにする。

    13. InputやLayerにノイズを入れる (Add noise to inputs, decay over time)


    ・DのInputに、ノイズを入れる(ガウシアンノイズなど) 。 Arjovsky et. al., Huszar, 2016

    ・Gの全てのLayerにガウシアンノイズを入れる。Zhao et. al. EBGAN

    ・Iterationごとにノイズの大きさを減衰させる。

    14. Gに学習、テストの両段階において、Dropoutを入れる。(Use Dropouts in G in both train and test phase)


    ・50%のDropoutを、Gの一部のレイヤーに適用する。

    ・テストの段階でもDropoutを適用したままにする。
    Phillip et. al. pix2pix







Thank you so much for authors of https://github.com/soumith/ganhacks