5
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

DCGAN

Last updated at Posted at 2020-06-13

#DCGAN (Deep Convolutional GAN)

##DCGANとは?
前回の記事で作ったシンプルなGANの生成器と識別器の双方に関して、単純な2層フィードフォワードを用いるのではなく、畳み込みニューラルネットワークを用いたGANをDCGANと言います。

##バッチ正規化(Batch Normalization)
今回のDCGANではバッチ正規化を使用しています。詳しい説明はこちらの方の記事が大変わかりやすいです。

簡単にバッチ正規化の導入メリットだけ紹介すると

  1. 学習を早く進行させることが可能
  2. 初期値にそれほど依存しなくなる
  3. 過学習を抑制することができる

などが挙げられます。
今回の実装ではkeras.layers.BatchNormalization関数がミニバッチの計算や更新を裏でうまくやってくれています。
それでは実際にDCGANを実装していきましょう!大まかな流れは前回の記事とほぼ同じです。

##いざ実装!!!

###1.諸々import

#まずは諸々import

%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np

from keras.datasets import mnist
from keras.layers import Activation, BatchNormalization, Dense, Dropout, Flatten, Reshape
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.models import Sequential
from keras.optimizers import Adam

###2.モデルの入力次元の設定

#モデルの入力次元の設定

img_rows = 28
img_cols = 28
channels = 1

img_shape = (img_rows, img_cols, channels)

#生成器への入力として使われるノイズベクトルの次元
z_dim = 100

###3.生成器の実装
生成器では、ノイズベクトルzから画像を生成するので、転置畳み込みを用いることになります。つまり下の図で、一番右のzベクトルから一番左の画像を生成するということですね。

スクリーンショット 2020-06-13 15.16.22.png

具体的なステップを以下にまとめます。

  1. ノイズベクトルを作り全結合層に通すことで7×7×256のテンソルに変換
  2. 転置畳み込み層によって、7×7×256を14×14×128に変換
  3. バッチ正規化を行い、Leaky ReLUを適用
  4. 転置畳み込み層により、14×14×128を14×14×64に変換。このステップでは高さと幅は不変
  5. バッチ正規化を行い、Leaky ReLUを適用
  6. 転置畳み込み層により、14×14×64を出力画像サイズ28×28×1に変換
  7. tanh関数を適用

Conv2DTransposeのパラメータに関してはこちらの記事を参考にさせていただきました。

#生成器

def build_generator(z_dim):
  model = Sequential()

  model.add(Dense(256*7*7, input_dim = z_dim))
  model.add(Reshape((7, 7, 256)))

  model.add(Conv2DTranspose(128, kernel_size=3, strides=2,padding='same'))
  model.add(BatchNormalization())
  model.add(LeakyReLU(alpha=0.01))

  model.add(Conv2DTranspose(1, kernel_size=3, strides=2, padding="same"))
  model.add(Activation('tanh'))

  return model

###4.識別器の実装
識別器ではCNNでおなじみのネットワーク構造をとります。やっていることをざっくりと説明すると、画像データを入力し畳み込みを行うことで最終的にその画像が本物かどうかの確率を計算します。詳しい内容は下のコードで確認してください。
#識別器

def build_discriminator(img_shape):

  model = Sequential()

  model.add(Conv2D(32, kernel_size=3, strides=2, input_shape=img_shape, padding='same'))
  model.add(LeakyReLU(alpha=0.01))

  model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
  model.add(BatchNormalization())
  model.add(LeakyReLU(alpha=0.01))

  model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
  model.add(BatchNormalization())
  model.add(BatchNormalization())
  model.add(LeakyReLU(alpha=0.01))

  model.add(Flatten())
  model.add(Dense(1, activation="sigmoid"))

  return model

###5.DCGANのコンパイル

#DCGANコンパイル
def build_gan(generator, discriminator):

  model = Sequential()

  model.add(generator)
  model.add(discriminator)

  return model

discriminator = build_discriminator(img_shape)
discriminator.compile(loss="binary_crossentropy", optimizer=Adam(), metrics=["accuracy"])

generator = build_generator(z_dim)
discriminator.trainable = False

gan = build_gan(generator, discriminator)
gan.compile(loss="binary_crossentropy", optimizer=Adam())

###6.学習の設定

#Training

losses = []
accuracies = []
iteration_checkpoints = []

def train(iterations, batch_size, sample_interval):
  (X_train, _),(_, _) = mnist.load_data()

  X_train = X_train / 127.5 -1.0
  X_train = np.expand_dims(X_train, 3)

  real = np.ones((batch_size, 1))
  fake = np.zeros((batch_size, 1))

  for iteration in range(iterations):

    idx = np.random.randint(0, X_train.shape[0], batch_size)
    imgs = X_train[idx]

    z = np.random.normal(0, 1, (batch_size, 100))
    gen_imgs = generator.predict(z)

    d_loss_real = discriminator.train_on_batch(imgs, real)
    d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
    d_loss,  accuracy = 0.5 * np.add(d_loss_real, d_loss_fake)

    z = np.random.normal(0, 1, (batch_size, 100))
    gen_imgs = generator.predict(z)

    g_loss = gan.train_on_batch(z, real)
    if iteration == 0:
      sample_images(generator)

    if ((iteration + 1) % sample_interval == 0):

      losses.append((d_loss, g_loss))
      accuracies.append(100 * accuracy)
      iteration_checkpoints.append(iteration+1)

      print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" %
                  (iteration + 1, d_loss, 100.0 * accuracy, g_loss))
      sample_images(generator)

###7.画像表示

def sample_images(generator, image_grid_rows=4, image_grid_columns=4):

  z = np.random.normal(0, 1, (image_grid_rows * image_grid_columns, z_dim))
  gen_imgs = generator.predict(z)
  gen_imgs = 0.5 * gen_imgs + 0.5

  fig, axs = plt.subplots(image_grid_rows,
                           image_grid_columns,
                           figsize=(4,4),
                           sharey=True,
                           sharex=True
                           )
  cnt = 0
  for i in range(image_grid_rows):
    for j in range(image_grid_columns):
      axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
      axs[i, j].axis('off')
      cnt += 1

###8. いざ学習!

iterations = 20000
batch_size = 128
sample_interval = 1000
train(iterations, batch_size, sample_interval)

##結果
↓初期のノイズ
スクリーンショット 2020-06-13 19.26.47.png
↓1000iterations
スクリーンショット 2020-06-13 19.26.57.png
↓10000iterations
スクリーンショット 2020-06-13 19.27.29.png
↓20000iterations
スクリーンショット 2020-06-13 19.27.43.png

どうでしょうか、データセットのmnistから取ってきた本物の手書き文字と見分けがつかないレベルの画像を生成することができました。また、前回のシンプルなGANでは、画像にピクセル単位のノイズが入ってしまっていましたが、今回DCGANを用いることでピクセル間の関係性を埋め込むことができ、ノイズのない綺麗な画像が生成されました。

5
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?