GAN(Generative Adversarial Networks)を学習させる際の14のテクニック

  • 70
    いいね
  • 0
    コメント

※この記事は"How to Train a GAN?" at NIPS2016を、ここを参考にして、私なりに解釈して、重要そうな部分を引用して翻訳したものです。役に立つことが多かったので共有致します。
※GANの説明は省略します。G=Generator、D=Discriminatorとして説明します。
※それぞれのテクニックに根拠はあまり書いてないですが、ほとんどが論文に書いてあった「こうすればうまくいった」というものです。GANの学習がうまくいかないときに試してみると良いと思います。

    1. 入力を正規化 (Normalize the inputs)

    ・DのInput(=Gの出力)となる画像を[-1,1]の範囲に正規化する。
    ・Gの出力が[-1,1]となるように、OutputのところをTanhにする。

    2. GのLoss関数を修正する (A modified loss function)

    ・Gの最適化では、min(log 1-D)がLossとなるが、実際にはmax(logD)をLossにした方が良い。 → 前者のLossでは勾配消失が生じやすい。Goodfellow et. al (2014)

    3. Zはガウス分布から (Use a spherical Z)

    ・GのInputとなるベクトルZは、一様分布ではなく、正規分布からサンプルする
    ・Interpolationする時は、点AからBの直線上の点ではなく、大きな円(Great Circle)上の点を用いて行う

    こちらのコードが参考になる

    4. BatchNormalization (Batch Norm)

    ・Real画像とFake画像(=Generated)で、別々のミニバッチを作る。= 一つのミニバッチに、RealとFakeが混ざらないようにする。
    batchmix.png
    ・Batchnormを使わないときは、Instance Normalizationを使う。(全てのサンプルから、平均を引いて、標準偏差で割る)

    5. ReLUやMaxPoolingのように、勾配がスパースになるものは避ける(Avoid Sparse Gradients: ReLU, MaxPool)

    ・スパースな勾配があると、GANの学習が安定しにくくなる。
    ・GとDの両方において、LeakyReLUが有効。
    ・ダウンサンプリングするときは、Average PoolingやStrideありのConv2dを使う。
    ・アップサンプリングするときは、PixelShuffleやStrideありのConvTranspose2dを使う。

    6. DのOutputの正解ラベル(Real/Fake)には、ノイジーなラベルを使う (Use Soft and Noisy Labels)

    ・普通はReal=1、Fake=0とするが、Real=0.7~1.2、Fake=0.0〜0.3からランダムにサンプルする。(Label Smoothing) Salimans et. al. 2016
    ・学習の際、たまにRealとFakeのラベルを入れかえる

    7. DCGANやハイブリッドモデル(DCGAN / Hybrid Models)

    ・DCGANが使える状況であれば使う。
    ・DCGANが使えないのであれば、VAE+GANやKL+GANのような、ハイブリッドモデルを使う。

    8. 強化学習における安定性のテクニックを使う (Use stability tricks from RL)

    ・過去のFake画像を保存しておいて、それを時々、DiscriminatorのInputとして用いたり、過去のGとDのモデルを保存しておき、数Iterationだけ、そのモデルを使って学習を行う。(Experience Replay)
    ・Deep Deterministic Policy Gradientsにおけるテクニックを使う。(詳細省略します)
    ・Pfau & Vinyals (2016)が参考になる。

    9. 最適化手法はAdamを使う (Use the ADAM Optimizer)

    ・現状だとAdamを使うのが一番良い。Radford et. al. 2015
    ・DiscriminatorにはSGD、GeneratorにはAdamを使うというのも良い。

    10. 学習がうまく進んでいるのかを確認する (Track failures early)

    ・DのLossが0になっていたら、うまくいっていない。
    ・勾配のノームが100を超えるくらい大きければ、うまくいっていない。
    ・うまくいっているときは、DのLossが小さな分散で、IterationごとにLossが小さくなっていく。大きい分散だとあまりうまくいっていない。
    ・GのLossが安定して小さくなっていっている場合、あまり意味のない画像を生成している可能性がある。

    11. DとGの学習を統計的に制御しない (Dont balance loss via statistics (unless you have a good reason to))

    ・Gの学習回数:Dの学習回数 の比率を見つけるのは、これまでに多くの人がやってきたけど、難しい。
    ・どうしても制御したいのであれば、例えば、「DのLossがAより大きければ、Dを学習」、「GのLossがBより大きければ、Gを学習」のように行う。

    12. 画像データにラベルがついているのであれば、使う (If you have labels, use them)

    ・例えば、顔画像であれば、性別や年齢のように、追加的なラベルある場合、Auxillary GANのように、Discriminatorがラベルの分類も行うようにする。

    13. InputやLayerにノイズを入れる (Add noise to inputs, decay over time)

    ・DのInputに、ノイズを入れる(ガウシアンノイズなど) 。 Arjovsky et. al., Huszar, 2016
    ・Gの全てのLayerにガウシアンノイズを入れる。Zhao et. al. EBGAN
    ・Iterationごとにノイズの大きさを減衰させる。

    14. Gに学習、テストの両段階において、Dropoutを入れる。(Use Dropouts in G in both train and test phase)

    ・50%のDropoutを、Gの一部のレイヤーに適用する。
    ・テストの段階でもDropoutを適用したままにする。
    Phillip et. al. pix2pix





Thank you so much for authors of https://github.com/soumith/ganhacks