Python
MachineLearning
DeepLearning
Keras
GANs

[失敗談]GANで戦車画像の生成

今回はGANを使って、画像の生成をしたいと思います。

GANを使ったアニメキャラクターの生成は多く見られるので、
今回は、硬派に戦車の画像生成をしたいと思います。

まずは、結果から

結果はこのような画像が生成されました。

0052_0000.png

左上の画像は戦車に見えなくもないですが、とても成功とは言い難く、
順を追って失敗理由を探ります。ちなみに、学習はColaboratoryを使い、
1時間ほどで完了しました。

皆さんがチャレンジする際は、私のような醜態をさらさずに、是非成功させてください。

戦車画像の取得

最初は以下の記事を参考に、yahooの画像検索でデータを集めようとしました。
https://qiita.com/ysdyt/items/565a0bf3228e12a2c503

しかし、類似の戦車が混在するため、結局、手作業でティーガー戦車の画像を
50枚集めました。

そして、kerasのData Augmentationを使って100枚の水増しをしています。
つまり、学習画像は150枚です。

DCGANの事例

こちら↓の記事にあるように、DCGANで高精細な画像を生成するのが、今回の目標です。
https://qiita.com/mattya/items/e5bfe5e04b9d2f0bbd47#web%E3%83%87%E3%83%A2

triwave33さんのコードをベースにちょっと変更しながら、パラメータをいじってみました。

以下に学習条件を示します。

batch size 潜在変数の次元 出力画像サイズ
10 100 128×128

kerasで実装したGenerator↓

    def build_generator(self):
        model = Sequential()

        model.add(Dense(input_dim=(self.z_dim + CLASS_NUM), output_dim=1024)) # z=100, y=10
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(0.2))
        model.add(Dense(128*8*8))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(0.2))
        model.add(Reshape((8,8,128)))
        model.add(UpSampling2D((2,2)))
        model.add(Convolution2D(128,5,5,border_mode='same'))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(0.2))
        model.add(Dropout(0.5))
        model.add(UpSampling2D((2,2)))
        model.add(Convolution2D(256,5,5,border_mode='same'))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(0.2))
        model.add(UpSampling2D((2,2)))
        model.add(Convolution2D(512,5,5,border_mode='same'))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(0.2))
        model.add(UpSampling2D((2,2)))
        model.add(Convolution2D(int(self.channels),5,5,border_mode='same'))
        model.add(Activation('tanh'))

        return model

kerasで実装したDiscriminator↓

    def build_discriminator():

        model = Sequential()

        model.add(Convolution2D(64,5,5,\
              subsample=(2,2),\
              border_mode='same',\
              input_shape=(self.img_rows,self.img_cols,(self.channels+CLASS_NUM))))
        model.add(LeakyReLU(0.2))
        model.add(Convolution2D(128,5,5,subsample=(2,2)))
        model.add(LeakyReLU(0.2))
        model.add(Convolution2D(256,5,5,subsample=(2,2)))
        model.add(LeakyReLU(0.2))
        model.add(Convolution2D(512,5,5,subsample=(2,2)))
        model.add(LeakyReLU(0.2))
        model.add(Flatten())
        model.add(Dense(256))
        model.add(LeakyReLU(0.2))
        model.add(Dropout(0.25))
        model.add(Dense(1))
        model.add(Activation('sigmoid'))

        return model

失敗原因を探る

ここからは失敗した原因を探ります。
以下、「Discriminator」をD、「Generator」をGと表記します。

Batch Normalizationについて

こちらによると、DCGANはBNをいれることによって、高精細な画像生成に成功しているとのこと。
しかし、私の場合は、BNを入れることによって、逆に悪化してしまいました。
具体的には、真っ黒な画像しか生成されませんでした。

こちら↓の記事にあるように、ちゃんとFakeとTrue画像を分けて学習させたのですが、
うまくいかなかったです。
https://qiita.com/underfitting/items/a0cbb035568dea33b2d7

また、こちらでも採用されているようにInstance Normalizationも入れてみましたが
悪化してしまいました。
https://qiita.com/t-ae/items/39daefcdbe8bf927e4f3

活性化関数について

同じく、こちらの記事↓を参考に、Dの活性化関数を「Relu→LeakyRelu」に変更しました。
若干の改善が見られました。
https://qiita.com/underfitting/items/a0cbb035568dea33b2d7

ガウスノイズについて

同じく、こちら↓の記事を参考に、Dにガウスノイズを加えてみたり、潜在変数を一様分布から
ガウスノイズに変更してみたりしましたが、改善は見られませんでした。
https://qiita.com/underfitting/items/a0cbb035568dea33b2d7

CNNの層について

GのCNN層について、層の数を2層→4層に変更してみました。
若干の改善が見られました。

画像の編集について

戦車の画像に写っている背景(建物や山、人など)を手作業で消して、戦車のみを抽出しました。
かなりの改善が見られました。

水増し数について

kerasのData Augmentationの水増し数を変更してみました。(100枚→1000枚)
改善は見られませんでした。

Adamのパラメータについて

今回は、Adamのパラメータ調整で、ほとんどの時間を使ったといっても過言ではありません。
これを調整するかどうかで、結果がまるっきり変わってきます。当然、改善幅は大きかったです。

一応、ここで使ったパラメータを示しますが、これは条件により異なってくると思います。

discriminator_optimizer = Adam(lr=.5e-5, beta_1=0.3)
generator_optimizer = Adam(lr=.8e-4, beta_1=0.5)

まとめ

上記の結果をまとめると、以下のとおりです。

手段 Batch Normalization 活性化関数 ガウスノイズ CNNの層 画像の編集 水増し数 Adamの調整
効果 ×

×:悪化、-:効果なし、△:若干効果あり、〇:効果あり

今回、失敗した一番の原因はデータ数だと考えています。
基となる学習画像が50枚というのは、ちょっと少なかったです。
Data Augmentationで水増しすれば何とかなると思っていました。

再度チャレンジするかは考え中です・・・