#Conditional GAN とは??
DCGANでは、mnistデータを学習に用いることで、綺麗な手書き文字の生成に成功しました。しかしながら、この生成器を実際に用いようとなると用途が限られてしまします。なぜなら、例えば「7」と言う手書き文字を作りたいと思っても、DCGANだと生成する文字を指定できないので、7が偶然生成されるまで生成器を動かし続ける必要があるからです。
ここでは、Conditional GANを導入することで、「7」を作りたい!と思った時に一発で「7」を生成できる生成器、つまり生成するクラスを指定できる生成器を実装していきます。
###CGNAの生成器
生成器は、乱数ベクトルzとラベルyを入力とし、偽のサンプルx*|yを生成する。生成器は与えられたラベルに対応し、できる限りリアルな見た目になるようにする。
じゃあ具体的にどうやって乱数ベクトルzとラベルyの二つを同時に入力として組み込むのか。それは後述「生成器の実装」をご覧ください。
###CGANの識別器
識別器の入力は、本物のサンプルとそのラベル(x,y)の組もしくは、偽のサンプルとそれを作るために使われたラベルの組み(x*|y,y)のどちらかである。識別器は入力の組みが本物かどうかを示す確率を、シグモイド関数σによって計算し出力する。
##実装!
今回はKerasのfunctional APIとSequential APIの両方をコラボして実装している。
###1. import
###import 文
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from keras.datasets import mnist
from keras.layers import Activation, BatchNormalization, Concatenate, Dense, Embedding, Flatten, Input, Multiply, Reshape
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.models import Model, Sequential
from keras.optimizers import Adam
###2. 入力次元の設定
#諸々の入力次元の設定
img_rows = 28
img_cols = 28
channels = 1
img_shape = (img_rows, img_cols, channels)
z_dim = 100
num_classes = 10
###3.CGANの生成器
CGANの実装の際の注意点。
ラベルとランダムベクトルzの複合表現を生成器への入力とする。
具体的な手順は、まずラベルをzと同じサイズのベクトルへと埋め込む(Embedding層を利用)。次にこの埋め込みラベルとzを掛け合わせる。その結果、複合表現が得られ、それを生成器への入力とする。以上に注意して実装する。
#CGANの生成器
def build_generator(z_dim):
model = Sequential()
model.add(Dense(256 * 7 * 7, input_dim = z_dim))
model.add(Reshape((7, 7, 256)))
model.add(Conv2DTranspose(128, kernel_size=3, strides=2, padding="same"))
model.add(BatchNormalization())
model.add(LeakyReLU(alpha=0.01))
model.add(Conv2DTranspose(64, kernel_size=3, strides=1, padding="same"))
model.add(BatchNormalization())
model.add(LeakyReLU(alpha=0.01))
model.add(Conv2DTranspose(1, kernel_size=3, strides=2, padding="same"))
model.add(Activation("tanh"))
return model
def build_cgan_generator(z_dim):
z = Input(shape=(z_dim, ))
label = Input(shape=(1, ), dtype="int32")
#Embedding層を用いて、ラベルの埋め込みを行う
#ラベルをz_dim次元の密ベクトルに変換する
label_embedding = Embedding(num_classes, z_dim, input_length=1)(label)
label_embedding = Flatten()(label_embedding)
#ベクトルzと、ラベルが埋め込まれたベクトルの、要素ごとの掛け算を行う
joined_embedding = Multiply()([z, label_embedding])
generator = build_generator(z_dim)
conditioned_img = generator(joined_embedding)
return Model([z, label], conditioned_img)
###4.CGANの識別器
識別器を実装する際の注意点。
まず、ラベルを画像の画素数と同じサイズ(784=28×28×1)のベクトルに埋め込む。次に、埋め込んだラベルが入力画像(28×28×1)と同じshapeになるようにReshapeする。ラベルが埋め込まれされに変形されたテンソルを、対応する画像に連結する。この複合表現を識別器の入力とする。よって、入力の次元は(28×28×2)となることに注意!
#CGANの識別器
def build_discriminator(img_shape):
model = Sequential()
model.add(Conv2D(64, kernel_size=3, strides=2, input_shape=(img_shape[0], img_shape[1], img_shape[2] + 1), padding="same"))
model.add(LeakyReLU(alpha=0.01))
model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
model.add(BatchNormalization())
model.add(LeakyReLU(alpha=0.01))
model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
model.add(BatchNormalization())
model.add(LeakyReLU(alpha=0.01))
model.add(Flatten())
model.add(Dense(1, activation="sigmoid"))
return model
def build_cgan_discriminator(img_shape):
img = Input(shape=img_shape)
label = Input(shape=(1, ), dtype="int32")
label_embedding = Embedding(num_classes, np.prod(img_shape), input_length=1)(label)
label_embedding = Flatten()(label_embedding)
label_embedding = Reshape(img_shape)(label_embedding)
#画像と、ラベルが埋め込まれたテンソルを結合する
concatenated = Concatenate(axis=-1)([img, label_embedding])
discriminator = build_discriminator(img_shape)
classification = discriminator(concatenated)
return Model([img, label], classification)
###5.CGANモデルの構築とコンパイル
#CGANモデルの構築とコンパイル
def build_gan(generator, discriminator):
z = Input(shape=(z_dim, ))
label = Input(shape=(1, ))
img = generator([z, label])
classification = discriminator([img, label])
model = Model([z, label], classification)
return model
discriminator = build_cgan_discriminator(img_shape)
discriminator.compile(loss="binary_crossentropy", optimizer=Adam(), metrics=["accuracy"])
generator = build_cgan_generator(z_dim)
discriminator.trainable = False
cgan = build_gan(generator, discriminator)
cgan.compile(loss="binary_crossentropy", optimizer=Adam())
###6.訓練アルゴリズム
#訓練アルゴリズム
accuracies = []
losses = []
def train(iterations, batch_size, sample_interval):
(X_train, Y_train), (_, _) = mnist.load_data()
X_train = X_train / 127.5 -1.0
X_train = np.expand_dims(X_train, axis=3)
real = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))
for iteration in range(iterations):
#識別器の訓練
idx = np.random.randint(0, X_train.shape[0], batch_size)
imgs, labels = X_train[idx], Y_train[idx]
z = np.random.normal(0, 1, (batch_size, z_dim))
gen_imgs = generator.predict([z, labels])
d_loss_real = discriminator.train_on_batch([imgs, labels], real)
d_loss_fake = discriminator.train_on_batch([gen_imgs, labels], fake)
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
#生成器の訓練
z = np.random.normal(0, 1, (batch_size, z_dim))
labels = np.random.randint(0, num_classes, batch_size).reshape(-1, 1)
g_loss = cgan.train_on_batch([z, labels], real)
if (iteration+1) % sample_interval == 0:
print("%d [D loss: %f, acc: %.2f%%] [G loss: %f]" % (iteration+1, d_loss[0], 100*d_loss[1], g_loss))
losses.append((d_loss[0], g_loss))
accuracies.append(100 * d_loss[1])
sample_images(iteration)
###7.画像の表示
def sample_images(iteration, image_grid_rows=2, image_grid_columns=5):
print("iteration : %d" % (iteration+1))
z = np.random.normal(0, 1, (image_grid_rows * image_grid_columns, z_dim))
labels = np.arange(0, 10).reshape(-1, 1)
gen_imgs = generator.predict([z, labels])
gen_imgs = 0.5 * gen_imgs + 0.5
fig, axs = plt.subplots(image_grid_rows,
image_grid_columns,
figsize=(10, 4),
sharey=True,
sharex=True)
cnt = 0
for i in range(image_grid_rows):
for j in range(image_grid_columns):
axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
axs[i, j].axis('off')
axs[i, j].set_title("数字: %d" % labels[cnt])
cnt += 1
###8.学習開始!
iterations = 12000
batch_size = 32
sample_interval = 1000
train(iterations, batch_size, sample_interval)
以上のコードを実装すると、ラベルを指定した時に、その指定したラベルに対応する画像を生成することに成功した。これでGANの実用性はグンと上がった。
###参考文献
・https://stats.stackexchange.com/questions/270546/how-does-keras-embedding-layer-work
・https://livebook.manning.com/book/gans-in-action/chapter-8/88