cGANで目的の画像を生成する
GANでの画像生成では、ランダムな画像が生成される。今回だとデータセットにmnistを使うが、0~9のうちどの数字が生成されるかはコントロールすることができない。
cGANの大まかな仕組み
GANではノイズをそのままgeneratorに入力して、出てきた画像をそのままdiscriminatorに入れて学習させていた。
cGANはノイズと一緒にラベルをgeneratorに入力し、出てきた画像をノイズと一緒に入れたラベルと同じラベルをdiscriminatorに入力するだけである。ラベルの入力については、modelの途中から入れる方法もあるが、
今回modelをあまり複雑にしたくなかったのラベルの入力はノイズや画像に組み合わせることにした。
ノイズにはラベルをonehotに変換しつなぎ合わせる。画像へのラベル付けはonehotではできないので、
すべての値が0の28x28の画像を10枚重ね合わせshape(28, 28, 10)に整形してラベルの値番目の画像を1で塗りつぶす。
そしてできたラベルと画像を重ね合わせdiscriminatorへ入力する。discriminatorは入力された画像がラベルの数字でありなおかつ本物かどうかを見分けるように学習していき、generatorはdiscriminatorを騙すように学習していく。
実装
まずはmodelから。
活性化関数はgenerator、discriminatorともにLeakyReLUを使用、これは入力の値が-1~1の間であり0以下の入力をreluで消さないためである。
class GAN:
def __init__(self):
self.imgShape = (28, 28, 11)
self.noiseDim = 128
self.discriminator = self.buildDiscriminator()
self.discriminator.compile(loss='binary_crossentropy',
optimizer=Adam(lr=1e-5, beta_1=0.1),
metrics=['accuracy'])
self.generator = self.buildGenerator()
self.combined = self.buildCombined()
self.combined.compile(loss='binary_crossentropy',
optimizer=Adam(lr=8e-4, beta_1=0.5))
def buildGenerator(self):
model = Sequential()
model.add(Dense(input_dim=(self.noiseDim + 10), units=1024)) # z=100, y=10
model.add(BatchNormalization())
model.add(LeakyReLU(0.2))
model.add(Dense(128*7*7))
model.add(BatchNormalization())
model.add(LeakyReLU(0.2))
model.add(Reshape((7, 7, 128), input_shape=(128*7*7,)))
model.add(UpSampling2D())
model.add(Conv2D(64, 5, padding='same'))
model.add(BatchNormalization())
model.add(LeakyReLU(0.2))
model.add(UpSampling2D())
model.add(Conv2D(1, 5, padding='same'))
model.add(Activation('tanh'))
return model
def buildDiscriminator(self):
model = Sequential()
model.add(Conv2D(64, kernel_size=5, strides=2, input_shape=self.imgShape, padding="same"))
model.add(LeakyReLU(0.2))
model.add(Dropout(0.25))
model.add(Conv2D(64, kernel_size=5, strides=2, padding="same"))
model.add(LeakyReLU(0.2))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(LeakyReLU(0.2))
model.add(Dropout(0.5))
model.add(Dense(1, activation='sigmoid'))
img = Input(shape=self.imgShape)
validity = model(img)
return Model(img, validity)
def buildCombined(self):
noise = Input(shape=(self.noiseDim + 10,))
label = Input(shape=(28, 28, 10,))
fakeImage = self.generator(noise)
fakeImage = Concatenate(axis=3)([fakeImage, label])
self.discriminator.trainable = False
valid = self.discriminator(fakeImage)
model = Model(input=[noise, label], output=valid)
return model
学習部分
def labelToImage(self, label):
channel = np.zeros((28, 28, 10))
channel[:, :, label] += 1
return channel
def conbiOnehot(self, noise, label):
oneHot = np.eye(10)[label]
return np.concatenate((noise, oneHot), axis=1)
def train(self, epochs, batchSize):
(X, Y), (_, _) = fashion_mnist.load_data()
X = (X.astype(np.float32) - 127.5) / 127.5
X = X.reshape([-1, 28, 28, 1])
discriminator = self.buildDiscriminator()
d_opt = Adam(lr=1e-5, beta_1=0.1)
discriminator.compile(loss='binary_crossentropy', optimizer=d_opt, metrics=['accuracy'])
g_opt = Adam(lr=.8e-4, beta_1=0.5)
self.combined.compile(loss='binary_crossentropy', optimizer=g_opt)
halfBatch = int(batchSize / 2)
for epoch in range(epochs + 1):
noise = np.random.normal(-1, 1, (halfBatch, self.noiseDim))
labelNoise = np.random.randint(0, 10, halfBatch)
noise = self.conbiOnehot(noise, labelNoise)
fakeImage = self.generator.predict(noise)
labelImage = np.array([self.labelToImage(i) for i in labelNoise])
fakeImage = np.concatenate((fakeImage, labelImage), axis=3)
index = np.random.randint(0, X.shape[0], halfBatch)
realImage, realLabel = X[index], Y[index]
labelImage = np.array([self.labelToImage(i) for i in realLabel])
realImage = np.concatenate((realImage, labelImage), axis=3)
x = np.concatenate((realImage, fakeImage))
y = np.array([1]*halfBatch + [0]*halfBatch)
disLoss = self.discriminator.train_on_batch(x, y)
noise = np.random.normal(-1, 1, (batchSize, self.noiseDim))
randomLabel = np.random.randint(0, 10, batchSize)
noise = self.conbiOnehot(noise, randomLabel)
randomImage = np.array([self.labelToImage(i) for i in randomLabel])
genLoss = self.combined.train_on_batch([noise, randomImage], np.ones(batchSize))
print("epoch:%d, [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, disLoss[0], 100 * disLoss[1], genLoss))
if epoch % 1000 == 0:
self.save_imgs(epoch)
self.generator.save('mnist_generator.h5')
def save_imgs(self, epoch):
r, c = 10, 10
label = [i for i in range(10)] * 10
noise = np.random.uniform(-1, 1, (r * c, self.noiseDim))
noehot = np.array([np.eye(10)[i] for i in label])
noise = np.concatenate((noise, noehot), axis=1)
genImgs = self.generator.predict(noise)
genImgs = 0.5 * genImgs + 0.5
fig, axs = plt.subplots(r, c)
cnt = 0
for i in range(r):
for j in range(c):
axs[i, j].imshow(genImgs[cnt, :, :, 0], cmap='gray')
axs[i, j].axis('off')
cnt += 1
fig.savefig("images/cGAN/mnist_%d.png" % epoch)
plt.close()
if __name__ == '__main__':
gan = GAN()
gan.train(epochs=10000, batchSize=64)
で、出てきたのがこちら。
時折崩れてはいますが、狙った数字を出力することができています。
cifar10でもやってみた
mnistでも出来たしcifarでもやってみようかな、ってことでやってみた。
cifar10で学習するにあたって、noiseサイズを1024に拡大、generator、discriminatorを少し整形した。
def buildGenerator(self):
model = Sequential()
model.add(Dense(input_dim=(self.noiseDim + 10), units=2048)) # z=100, y=10
model.add(BatchNormalization())
model.add(LeakyReLU(0.2))
model.add(Reshape((4, 4, 128), input_shape=(128 * 4 * 4,)))
model.add(UpSampling2D((2, 2)))
model.add(Conv2D(64, (5, 5), padding='same'))
model.add(BatchNormalization())
model.add(LeakyReLU(0.2))
model.add(UpSampling2D((2, 2)))
model.add(Dropout(0.25))
model.add(Conv2D(32, (5, 5), padding='same'))
model.add(BatchNormalization())
model.add(LeakyReLU(0.2))
model.add(UpSampling2D((2, 2)))
model.add(Conv2D(3, (5, 5), padding='same'))
model.add(Activation('tanh'))
return model
def buildDiscriminator(self):
model = Sequential()
model.add(Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=(self.imageShape)))
model.add(LeakyReLU(0.2))
model.add(Conv2D(128, (5, 5), strides=(2, 2)))
model.add(LeakyReLU(0.2))
model.add(Flatten())
model.add(Dense(256))
model.add(LeakyReLU(0.2))
model.add(Dropout(0.5))
model.add(Dense(1))
model.add(Activation('sigmoid'))
return model
イテレーションを100000回しましたが手書き数字のようにうまく形にならず、なんとなくその形に見える程度になった。
下のgifは10000イテレーションごとの結果で左から、飛行機、車、鳥、猫、鹿、犬、カエル、馬、船、トラックである。