Help us understand the problem. What is going on with this article?

PyTorch+Google ColabでVariational Auto Encoderをやってみた

More than 1 year has passed since last update.

PyTorch+Google ColabでVariational Auto Encoderをやってみました。MNIST, Fashion-MNIST, CIFAR-10, STL10の画像を処理しました。また、Variationalではなく、ピュアなAuto EncoderをData Augmentationを使ってやってみましたが、これはあまりうまく行きませんでした。

データとネットワーク構成について

MNIST, Fashion-MNIST, CIFAR-10, STL10は以下の通りです。

MNIST
mnist.png

Fashion-MNIST
fashion-mnist.png

CIFAR-10
cifar.png

STL-10
stl.png

STL-10は馴染みが薄いかもしれませんが、スタンフォード大学が公開している教師なし学習向けのCIFARに似たなデータセットです。ただし解像度が大きく、CIFARが32x32であるのに対して、STLは96x96です。

ネットワーク構成は以下の通りです。CNNによるVAEを作ります。1x1畳み込みから構成されるボトルネック層が1層、3x3畳み込みの層が3層、計4層のエンコーダーです。ダウンサンプリングはPoolingではなくConv2dのstrideを使っています。

エンコーダー チェンネル数 カーネル MNIST/Fashion CIFAR-10 STL-10
ボトルネック 32 1 28x28 32x32 96x96
1層目 64 3 28x28 32x32 96x96
2層目 128 3 14x14 8x8 24x24
3層目 256 3 7x7 4x4 6x6

ちなみにデコーダー側はこうなります。ConvTranspose2dでkernel_size=strideとしてアップサンプリングしました。ボトルネックの最後にはシグモイド活性化関数をつけています。

デコーダー チェンネル数 カーネル MNIST/Fashion CIFAR-10 STL-10
3層目 128 stride 7x7 4x4 6x6
2層目 64 stride 14x14 8x8 24x24
1層目 32 stride 28x28 32x32 96x96
ボトルネック 3 1 28x28 32x32 96x96

また中間層は以下のようになります。

中間層 チャンネル数
エンコーダー出力 128×データごとの解像度
μ 64
σ 64
デコーダー入力 128×データごとの解像度

潜在変数は64個で表現しました。平均(μ)と対数分散(σ)は活性化関数なしの線形活性化関数です。コード全体はこちらになります。

コード:https://gist.github.com/koshian2/64e92842bec58749826637e3860f11fa

また本実装はPyTorchのVAEの例、CNNのVAEのPyTorchの例を参考にしました。各リポジトリは以下の通りです。

理論的な話

自分が説明するより詳しい説明がネットにいっぱいあるので、そちらを参照してください。直感的な説明ならKerasのブログがわかりやすいです。
Building Autoencoders in Keras
https://blog.keras.io/building-autoencoders-in-keras.html

日本語の資料ならこちらもおすすめです。
猫でも分かるVariational AutoEncoder
https://www.slideshare.net/ssusere55c63/variational-autoencoder-64515581

数式的な証明はこちらがわかりやすいです。自分は半分ぐらいしか理解できませんでしたが、ベイズの定理から始まり、潜在変数と入力データの同時確率を求めるというベイズ統計のアプローチをディープラーニングで用いるという手法です。VAEの場合は、潜在変数が標準正規分布に従うという仮定をおく(標準正規分布になるように最適化する)ことで、reparametrization trickを閉じた式で表現できるようになっています。難しいですが、非常に美しい理論展開です。
Variational Autoencoder: Intuition and Implementation
https://wiseodd.github.io/techblog/2016/12/10/variational-autoencoder/

結果

結果は以下の通りです。

各データセットにつき、1枚目の画像はテストデータをエンコーダーの入力に入れたときのデコーダー側の出力(reconstruction)、2枚目の画像は正規乱数をデコーダーの入力に通したランダムサンプリング(sampling)です。reconstruction, samplingともに最終エポックのものを表示しています。3枚目はオリジナルです(再掲)。

STL以外は200エポック、STLはデータ数が多いので処理時間の関係上100エポック訓練させました。

MNIST

reconstruction
mnist_reconstruction.png

sampling
mnist_sampling.png

かなり上手くいっています。MNISTは簡単なデータセットなので。

ground truth
mnist.png

Fashion-MNIST

reconstruction
fmnist_reconstruction.png

sampling
fmnist_sampling.png

Fashion-MNISTはMNISTより若干難しいデータセットです。輪郭のボケは見えるものの、だいたい上手くいっているように見えます。

ground truth
fashion-mnist.png

CIFAR-10

reconstruction
cifar_reconstruction.png

sampling
cifar_sampling.png

カラー画像になるとかなり厳しい印象を受けます。ネットワークをもっと深くすれば生成性能は上がるかもしれませんが、reconstructionでも相当ボケていますし、samplingはもっとひどくて雰囲気はわかるもののピンぼけが厳しいですね。

ground truth
cifar.png

STL-10

STL-10のデータ構成はtrain, testだけでなく、unlabled(ラベル付されていない)データがあります。trainが5000件、testが8000件、unlabledが10万件用意されているので、unlabledで訓練させtestで再現画像を生成させました。

reconstruction
stl_reconstruction.png

sampling
stl_sampling.png

正直ひどいですね。samplingなんかGANの失敗画像みたいな感じになっています。解像度が大きくなるとVAE特有のピンぼけさが目立つようになります。

ground truth
stl.png

学習経過

VAEのエラーはBinary Cross Entropy+KL Divergenceなので、この評価方法はひょっとしたら正しくないかもしれませんが、画像1枚あたりのエラーを画像の(ピクセル数×カラーチャンネル数)で割った値をプロットしてみます。カラーチャンネル数はモノクロなら1、カラーなら3です。

vae_train_loss.png
vae_test_loss.png

やはり生成された画像を見たとおり、カラー画像に対するエラー推移があまりよろしくないです。STLのテストなんてほとんど学習が進んでいません。

プレーンなAuto Encoder+Data Augmentationの模索

VAEがプレーンなAuto Encoderに比べて良い点は、潜在変数を標準正規分布と仮定することで、空間がスカスカにならないことです。ただその一方で、そのパラメトリックな仮定が悪さをしているのか、あるいはモデルの深さが足りないこととの相互作用なのか、解像度が高い画像に対してはぼやけるという欠点があります。

スカスカにならないのなら、例えばShake-Shakeのようなネットワークの中でData Augmentationを行ってプレーンなAuto Encoderに食わせたらどうだろう?と思ってやってみました。結果はダメでした。他にもShake-Shakeではなく、MixupのようなData Augmentationを行っても似たような失敗例になりました。

コードはこちらにあります。
https://gist.github.com/koshian2/2098e2261d673c818f6bdc51fa485e86
なお、Shake-Shakeの実装はこちらのリポジトリを使わせていただきました。
https://github.com/hysts/pytorch_shake_shake/blob/master/functions/shake_shake_function.py

ポイントだけかいつまんでいうと、

  • エンコーダー・デコーダーともに1層あたり畳み込みを2つ作り、それをShake-Shakeで結合する
  • 中間層はエンコーダー側はtanhで、デコーダー側はReLUで結ぶ。
  • 中間層からの生成については、reconstructionはそのまま、samplingのみ2つに分割し、(1)ブートストラップサンプリング、(2)ランダムサンプリングの2つのパターンを作る
  • ブートストラップサンプリングについては、直前の中間層の行列(ミニバッチサイズ×潜在変数の次元)をキャッシュし、潜在変数の次元ごとにnp.random.choiceで選んで、次元間で独立にサンプリングする(ミニバッチ単位でブートストラップすると、reconstructionと変わらないため)。つまり、このブートストラップは(ミニバッチサイズ)^(潜在変数の次元)の組み合わせがある。
  • ランダムサンプリングについては、-1~1の一様乱数で潜在変数をサンプリングする。ただし、潜在空間が一様分布に従うという保証はない(VAEの場合は、正規分布に従うという仮定がある)。

ピュアなAEの結果

MNISTとCIFARでやってみました。結果を見てみましょう。reconstructionはとてもきれいです。

reconstruction
shakeshake_mnist_reconstruction.png
shakeshake_cifar_reconstrunction.png

ただ、samplingが厳しいです。ブートストラップだと、かろうじて原型は残っていますが、何がんなんだかわからない。多分このブートストラップだと変数間の相関を無視しているからだと思います。

bootstrap sampling
shakeshake_mnist_bootstrap.png
shakeshake_cifar_bootstrap.png

そしてランダムサンプリングだとほぼノイズですね。

random sampling
shakeshake_mnist_random.png
shakeshake_cifar_random.png

あとかなり気になったのが、ピュアなAEにしたらVAEと比べてロスの値がガクッと(VAEが100ぐらいとしたら、AEが0.00…いくつ)落ちたことです。ピュアなAEではKLダイバージェンスを入れていません。ひょっとすると、VAEではロスのうち、エンコーダーの入力画像をデコーダーの出力画像を等しくするクロスエントロピーよりも、KLダイバージェンスのほうが圧倒的に支配力が強いのではないかなと思われます。不鮮明になってしまうのはひょっとしたらこれもあるかもしれません。

やはり、潜在変数の分布は仮定したほうが良いのでしょうか?その一方でDCGANをはじめGANファミリーが画像生成で素晴らしい成績を残しているので、エンコーダー・デコーダーモデルが悪いというわけではないはずです。VAEは、GANに見られるようなDiscriminator/Generatorの片方のロスが0になってモデルが崩壊するということがないので、学習が安定しているというメリットがあると思います。

なので、VAEでGAN並の高解像度な出力ができたら、個人的にはVAEを使ったほうが実用面でのメリットは今の所は大きいかなと感じています。GANの学習安定化について理論だったアプローチがあれば知りたいですし、Auto Encoderで高解像度な出力を得る方法があったらコメントで教えていただければ助かります。

Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away