LoginSignup
4
12

More than 3 years have passed since last update.

conditional GAN

Last updated at Posted at 2020-06-18

Conditional GAN とは??

DCGANでは、mnistデータを学習に用いることで、綺麗な手書き文字の生成に成功しました。しかしながら、この生成器を実際に用いようとなると用途が限られてしまします。なぜなら、例えば「7」と言う手書き文字を作りたいと思っても、DCGANだと生成する文字を指定できないので、7が偶然生成されるまで生成器を動かし続ける必要があるからです。

ここでは、Conditional GANを導入することで、「7」を作りたい!と思った時に一発で「7」を生成できる生成器、つまり生成するクラスを指定できる生成器を実装していきます。

CGANの全体アーキテクチャ

image.png

CGNAの生成器

image.png
生成器は、乱数ベクトルzとラベルyを入力とし、偽のサンプルx*|yを生成する。生成器は与えられたラベルに対応し、できる限りリアルな見た目になるようにする。
じゃあ具体的にどうやって乱数ベクトルzとラベルyの二つを同時に入力として組み込むのか。それは後述「生成器の実装」をご覧ください。

CGANの識別器

image.png
識別器の入力は、本物のサンプルとそのラベル(x,y)の組もしくは、偽のサンプルとそれを作るために使われたラベルの組み(x*|y,y)のどちらかである。識別器は入力の組みが本物かどうかを示す確率を、シグモイド関数σによって計算し出力する。

実装!

今回はKerasのfunctional APIとSequential APIの両方をコラボして実装している。

1. import


###import 文

%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np

from keras.datasets import mnist
from keras.layers import Activation, BatchNormalization, Concatenate, Dense, Embedding, Flatten, Input, Multiply, Reshape
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.models import Model, Sequential
from keras.optimizers import Adam

2. 入力次元の設定

#諸々の入力次元の設定
img_rows = 28
img_cols = 28
channels = 1

img_shape = (img_rows, img_cols, channels)

z_dim = 100
num_classes = 10

3.CGANの生成器

CGANの実装の際の注意点。
ラベルとランダムベクトルzの複合表現を生成器への入力とする。
image.png

具体的な手順は、まずラベルをzと同じサイズのベクトルへと埋め込む(Embedding層を利用)。次にこの埋め込みラベルとzを掛け合わせる。その結果、複合表現が得られ、それを生成器への入力とする。以上に注意して実装する。

#CGANの生成器

def build_generator(z_dim):

  model = Sequential()

  model.add(Dense(256 * 7 * 7, input_dim = z_dim))
  model.add(Reshape((7, 7, 256)))

  model.add(Conv2DTranspose(128, kernel_size=3, strides=2, padding="same"))
  model.add(BatchNormalization())
  model.add(LeakyReLU(alpha=0.01))

  model.add(Conv2DTranspose(64, kernel_size=3, strides=1, padding="same"))
  model.add(BatchNormalization())
  model.add(LeakyReLU(alpha=0.01))

  model.add(Conv2DTranspose(1, kernel_size=3, strides=2, padding="same"))
  model.add(Activation("tanh"))

  return model

def build_cgan_generator(z_dim):

  z = Input(shape=(z_dim, ))
  label = Input(shape=(1, ), dtype="int32")

  #Embedding層を用いて、ラベルの埋め込みを行う
  #ラベルをz_dim次元の密ベクトルに変換する
  label_embedding = Embedding(num_classes, z_dim, input_length=1)(label)
  label_embedding = Flatten()(label_embedding)

  #ベクトルzと、ラベルが埋め込まれたベクトルの、要素ごとの掛け算を行う
  joined_embedding = Multiply()([z, label_embedding])

  generator = build_generator(z_dim)
  conditioned_img = generator(joined_embedding)

  return Model([z, label], conditioned_img)

4.CGANの識別器

識別器を実装する際の注意点。
image.png
まず、ラベルを画像の画素数と同じサイズ(784=28×28×1)のベクトルに埋め込む。次に、埋め込んだラベルが入力画像(28×28×1)と同じshapeになるようにReshapeする。ラベルが埋め込まれされに変形されたテンソルを、対応する画像に連結する。この複合表現を識別器の入力とする。よって、入力の次元は(28×28×2)となることに注意!

#CGANの識別器

def build_discriminator(img_shape):

  model = Sequential()

  model.add(Conv2D(64, kernel_size=3, strides=2, input_shape=(img_shape[0], img_shape[1], img_shape[2] + 1), padding="same"))
  model.add(LeakyReLU(alpha=0.01))

  model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
  model.add(BatchNormalization())
  model.add(LeakyReLU(alpha=0.01))

  model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
  model.add(BatchNormalization())
  model.add(LeakyReLU(alpha=0.01))

  model.add(Flatten())
  model.add(Dense(1, activation="sigmoid"))

  return model

def build_cgan_discriminator(img_shape):

  img = Input(shape=img_shape)
  label = Input(shape=(1, ), dtype="int32")

  label_embedding = Embedding(num_classes, np.prod(img_shape), input_length=1)(label)
  label_embedding = Flatten()(label_embedding)
  label_embedding = Reshape(img_shape)(label_embedding)

  #画像と、ラベルが埋め込まれたテンソルを結合する
  concatenated = Concatenate(axis=-1)([img, label_embedding])

  discriminator = build_discriminator(img_shape)
  classification = discriminator(concatenated)

  return Model([img, label], classification)

5.CGANモデルの構築とコンパイル

#CGANモデルの構築とコンパイル

def build_gan(generator, discriminator):

  z = Input(shape=(z_dim, ))
  label = Input(shape=(1, ))

  img = generator([z, label])
  classification = discriminator([img, label])

  model = Model([z, label], classification)

  return model

discriminator = build_cgan_discriminator(img_shape)
discriminator.compile(loss="binary_crossentropy", optimizer=Adam(), metrics=["accuracy"])

generator = build_cgan_generator(z_dim)
discriminator.trainable = False

cgan = build_gan(generator, discriminator)
cgan.compile(loss="binary_crossentropy", optimizer=Adam())

6.訓練アルゴリズム

#訓練アルゴリズム

accuracies = []
losses = []

def train(iterations, batch_size, sample_interval):

  (X_train, Y_train), (_, _) = mnist.load_data()
  X_train = X_train / 127.5 -1.0
  X_train = np.expand_dims(X_train, axis=3)

  real = np.ones((batch_size, 1))
  fake = np.zeros((batch_size, 1))

  for iteration in range(iterations):

    #識別器の訓練
    idx = np.random.randint(0, X_train.shape[0], batch_size)
    imgs, labels = X_train[idx], Y_train[idx]

    z = np.random.normal(0, 1, (batch_size, z_dim))
    gen_imgs = generator.predict([z, labels])

    d_loss_real = discriminator.train_on_batch([imgs, labels], real)
    d_loss_fake = discriminator.train_on_batch([gen_imgs, labels], fake)
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

    #生成器の訓練
    z = np.random.normal(0, 1, (batch_size, z_dim))
    labels = np.random.randint(0, num_classes, batch_size).reshape(-1, 1)

    g_loss = cgan.train_on_batch([z, labels], real)

    if (iteration+1) % sample_interval == 0:
      print("%d [D loss: %f, acc: %.2f%%] [G loss: %f]" % (iteration+1, d_loss[0], 100*d_loss[1], g_loss))
      losses.append((d_loss[0], g_loss))
      accuracies.append(100 * d_loss[1])

      sample_images(iteration)

7.画像の表示

def sample_images(iteration, image_grid_rows=2, image_grid_columns=5):

  print("iteration : %d" % (iteration+1))

  z = np.random.normal(0, 1, (image_grid_rows * image_grid_columns, z_dim))
  labels = np.arange(0, 10).reshape(-1, 1)
  gen_imgs = generator.predict([z, labels])
  gen_imgs = 0.5 * gen_imgs + 0.5
  fig, axs = plt.subplots(image_grid_rows,
                            image_grid_columns,
                            figsize=(10, 4),
                            sharey=True,
                            sharex=True)
  cnt = 0
  for i in range(image_grid_rows):
    for j in range(image_grid_columns):
        axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
        axs[i, j].axis('off')
        axs[i, j].set_title("数字: %d" % labels[cnt])
        cnt += 1

8.学習開始!

iterations = 12000
batch_size = 32
sample_interval = 1000

train(iterations, batch_size, sample_interval)

出力結果

スクリーンショット 2020-06-18 13.11.54.png

↓1000iteration
スクリーンショット 2020-06-18 13.04.45.png
↓3000iteration
スクリーンショット 2020-06-18 13.06.56.png

↓5000iteration
スクリーンショット 2020-06-18 13.05.45.png

↓12000iteration
スクリーンショット 2020-06-18 13.07.22.png

以上のコードを実装すると、ラベルを指定した時に、その指定したラベルに対応する画像を生成することに成功した。これでGANの実用性はグンと上がった。

参考文献

https://stats.stackexchange.com/questions/270546/how-does-keras-embedding-layer-work
https://livebook.manning.com/book/gans-in-action/chapter-8/88

4
12
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
12