全結合型GANの問題点
前回の記事でGANの基本構造を解説しました。ここからは基本構造を少し変更して、GANの性能アップを図りたいと思います。扱うデータは前回同様、手書きデータのmnistです。
ラベル(0-9の数字)情報を入れて学習して、得られた生成モデルから任意の文字を吐き出させて遊びたいのですが、その前に生成モデル、識別モデルの性能向上を図ります。
今回扱うデータもmnistということで当然画像データです。前回、基本モデルとして紹介した(通常の)GANでは識別器、生成器ともに通常の全結合型ニューラルネットワークを用いていました。生成された画像はこんな感じでした。
生成された文字のクオリティはともかく、画像にピクセル単位のノイズが混じっていますね。(通常の)ニューラルネットでは、隣り合うピクセル間の関係性を全く考慮していません。
そもそもmnistは縦28x横28の画像データですが、前回のGANではこれを28*28=784の一列のデータとして扱っています。さらに、全結合型のニューラルネットワークは各ノード間の情報を入れ替えても、対応する重みを入れ替えてしまえば全く同じ出力が得られるので、各ノードは独立していることになります。これでは、ピクセル間の関係性を埋め込むことはできませんね。
その影響として、生成された画像は砂絵のようなノイズが発生します。人間の目からみれば、何もないところに局所的なノイズが映り込む生成画像と、ノイズが混入しない本物画像を簡単に見分けることができますし、人間が数字を書く時にこんなノイズが必要とも思いませんが、前回用いたのGANの仕組みでは、これを見分ける方法もノイズを作らずに生成する方法も持ちません。
対策
さて、前振りが長くなってしまいましたが、画像に適したニューラルネットワーク、すなわちピクセル間の関係を埋め込めるネットワークはよく知られていますね。畳み込みニューラルネットワーク(CNN)です。前回紹介した、GANの構造で、Generator, Discriminatorのネットワークだけを、CNNに変えてみましょう。Deep Convolutional GAN すなわちDCGANです。コードを以下に示します。
def build_generator(self):
noise_shape = (self.z_dim,)
model = Sequential()
model.add(Dense(1024, input_shape=noise_shape))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(Dense(128*7*7))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(Reshape((7,7,128), input_shape=(128*7*7,)))
model.add(UpSampling2D((2,2)))
model.add(Convolution2D(64,5,5,border_mode='same'))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(UpSampling2D((2,2)))
model.add(Convolution2D(1,5,5,border_mode='same'))
model.add(Activation('tanh'))
model.summary()
return model
def build_discriminator(self):
img_shape = (self.img_rows, self.img_cols, self.channels)
model = Sequential()
model.add(Convolution2D(64,5,5, subsample=(2,2),\
border_mode='same', input_shape=img_shape))
model.add(LeakyReLU(0.2))
model.add(Convolution2D(128,5,5,subsample=(2,2)))
model.add(LeakyReLU(0.2))
model.add(Flatten())
model.add(Dense(256))
model.add(LeakyReLU(0.2))
model.add(Dropout(0.5))
model.add(Dense(1))
model.add(Activation('sigmoid'))
return model
前回の基本のGANとの違いはここだけです。generatorとdiscriminatorのネットワークをCNNを使ったものに変更しました。パーツを入れ替えただけ。新しいGANを理解する時には基本構造から何を入れ替えたのか、例えば
・ネットワーク構造なのか
・損失関数なのか
・optimizerか
・入力するデータか
・正解ラベルか
などを意識すると良いと思います。基本が大事。
さて、DCGANのネットワークについては、前回もリンクした
はじめてのGAN
で非常に詳細に説明されており、勉強になりました。
ここでは、DCGANのネットワークについては今回も深く立ち入らないでおこうと思います。
(というか、個人的にはGANは考慮する事柄が多くて、ネットワークの構造自体に思いが向きづらいです。論文を読んでいても「まぁ、表現力が高ければこれじゃなくてもいいよ」とか書いてありますし。。)
とにかく、この記事で強調したかったのは、基本構造を押さえておけば、派生形はパーツを書き換えるだけでオッケーということです。
結果
ノイズのないきれいな画像が生成されていますね。
コードはgithubにアップしています。
次回はこのDCGANを使って、zと生成画像の関係について考察してみたいと思います。