LoginSignup
38
25

More than 5 years have passed since last update.

[Survey]Adversarial Autoencoders

Last updated at Posted at 2016-04-13

Adversarial Autoencoders

この論文の目的は、autoencoderlatent code vectorが任意の分布になるような学習方法を提案することです。
adversarial autoencodersに関する論文としては以下3つが有名で各サイトで説明されているので、今回は細かい説明は省略したいと思います。

  1. Generative Adversarial Nets
  2. Adversarial Autoencoders
  3. Unsupervised Representation learning With Deep Convolutional Generative Adversarial Networks

1番の論文に関しては下記サイトがとても詳しく、Tensorflowによる実装も掲載されています。
http://evjang.com/articles/genadv1

Adversarial autoencoders

Adversarial autoencodersのブロック図と学習の流れを示したのが下記の図になります。
上段がautoencoderになります。autoencoderは入力されたデータの次元を圧縮して(encode)、元の次元戻す(decode)処理をするもので、なるべく元の入力と同じになるように重みを学習します。autoencoderlatent code vector $z$の分布$q(z)$と任意の分布$p(z)$が同じになるようにautoencoderを学習させます。$q(z)$と$p(z)$が似ているのかどうか判断するのがDiscriminatorの役割です。もし$q(z)$と$p(z)$が同じになるようにautoencoderを学習できれば、任意の分布$p(z)$から生成した$z$をDecoderに入力してあげると、$x$を入力しなくても$x$に似たデータが$z$から生成できます。

adversarialautoencoder.png

最適化の式は下記のようになります。
ここで$D$はDiscriminator modelで$G$はGenerator Model、$p_{data}$任意の分布、$p(z)$はGeneratorにより生成された$z$の分布です。
学習は下記のようにおこないます。
1. true sample(任意の分布のデータ)とfake sample(Generatorにより生成されたデータ)をうまく分離できるようにDiscriminatorを学習します。下記の式の第一項はtrue sampleのときに1と言うようになれば値が大きくなります。第二項はfake sample($G(z)$)のときに0と言うようになれば値が大きくなります。足したものを$D$に関して最大化すればtrue samplefake sampleがうまく分離できるようになります。
2. GeneratorDiscriminatorを騙せるように学習します。Discriminatorをうまく騙せると$D(G(z))$が1になります。なので下記の式を$G$に関して最小化すればDiscriminatorをうまく騙せるようになります。
image

実際にMNISTのデータを学習させる時は下記のような構造にします。基本的には変わりありませんが、Discriminatorにデータが0なのか1なのか...9なのかを示したone hot vectorを渡します。

image

Result

もうすでにAdversarial autoencodersのコードはいろいろ公開されていますが、自分でもtensorflowで実装して学習させてみました。(もはや何番煎じかわかりませんが...)
実装の際には下記のコードやサイトを参考にしました。
http://musyoku.github.io/2016/02/22/adversarial-autoencoder/
https://github.com/takerum/adversarial_autoencoder
https://www.reddit.com/r/MachineLearning/comments/3ybj4d/151105644_adversarial_autoencoders/?

implementation

ネットワーク構造及び学習パラメータは下記の通りです。

・Encoder, Decoder, Discriminatorともに3 layer network
・すべてのネットワークのhidden layerのunit数は1000
・Encoderはfirst, second layerでReLUを使用、last layerは何もなし
・Decoderはfirst, second layerでReLUを使用、last layerはsigmoidを使用。
・Discriminatorは、first, second、last layerでReLU使用
・Latent code vectorは2次元
・100 epoch
・batch size 100
・Batch Normalization,weight decayは無し
・OptimizerはAdam
・Learning rateは、autoencoderとgeneratorは0.001でdiscriminatorは0.0002
・generatorだけ学習回数を2倍にしました。(autoencoder, discriminator, generatorの順で学習させるとうまく学習が収束しなかったので、autoencoder, discriminator, generator, generatorという感じにしました。ほんとうにこれでいいのか不明ですが・・・。)
・autoencoderは最初にpre-trainingしました。

実際に学習させた結果下記のようになりました。

10 2D Gaussian

MNISTの0〜9のデータを10個の2D Gaussian分布に押し込みます。
samples.png

epoch毎の$q(z)$の分布をgif animationにしてみました。qittaは画像サイズが1M以上だと圧縮されてgif animationが動かなくなるので、160x160にResizeしました。小さすぎてなんだかよくわからなくなってしまいましたが・・・。
s_q_z_f10.gif

decoderに$p(z)$を入力した時のepoch毎の出力結果をgif animationにしたものです。

s_img_d00.gifs_img_d01.gifs_img_d02.gif
s_img_d03.gifs_img_d04.gifs_img_d05.gif
s_img_d06.gifs_img_d07.gifs_img_d08.gif
s_img_d09.gif

swiss roll

MNISTの0〜9のデータを2D swiss roll分布に押し込みます。
samples.png

epoch毎の$q(z)$の分布をgif animationにしてみました。
s_q_z_f10.gif

decoderに$p(z)$を入力した時のepoch毎の出力結果をgif animationにしたものです。
s_img_d00.gifs_img_d01.gifs_img_d02.gif
s_img_d03.gifs_img_d04.gifs_img_d05.gif
s_img_d06.gifs_img_d07.gifs_img_d08.gif
s_img_d09.gif

最後に

試行錯誤のすえ、ようやく学習が収束するようになりました。自分は色々な人の実装を参考にしながらやったので何とかなりましたが、これを最初に実装した人はすごいなと思います。実装してみた感想としては、うまく収束させるのが難しいです。ただ3つを順番に学習させればいいというわけではなさそうで、3つのバランスが重要な気がします。できたばかりでコードがぐちゃぐちゃのため整理したらGithubにあげようかなと思っています。

code

38
25
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
38
25