2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

conditionalGANの実装

Posted at

cGANで目的の画像を生成する

GANでの画像生成では、ランダムな画像が生成される。今回だとデータセットにmnistを使うが、0~9のうちどの数字が生成されるかはコントロールすることができない。
mnist_25000.png

cGANの大まかな仕組み

GANではノイズをそのままgeneratorに入力して、出てきた画像をそのままdiscriminatorに入れて学習させていた。
cGANはノイズと一緒にラベルをgeneratorに入力し、出てきた画像をノイズと一緒に入れたラベルと同じラベルをdiscriminatorに入力するだけである。ラベルの入力については、modelの途中から入れる方法もあるが、
今回modelをあまり複雑にしたくなかったのラベルの入力はノイズや画像に組み合わせることにした。
ノイズにはラベルをonehotに変換しつなぎ合わせる。画像へのラベル付けはonehotではできないので、
すべての値が0の28x28の画像を10枚重ね合わせshape(28, 28, 10)に整形してラベルの値番目の画像を1で塗りつぶす。
そしてできたラベルと画像を重ね合わせdiscriminatorへ入力する。discriminatorは入力された画像がラベルの数字でありなおかつ本物かどうかを見分けるように学習していき、generatorはdiscriminatorを騙すように学習していく。

実装

まずはmodelから。
活性化関数はgenerator、discriminatorともにLeakyReLUを使用、これは入力の値が-1~1の間であり0以下の入力をreluで消さないためである。

class GAN:
    def __init__(self):
        self.imgShape = (28, 28, 11)
        self.noiseDim = 128

        self.discriminator = self.buildDiscriminator()
        self.discriminator.compile(loss='binary_crossentropy',
                                   optimizer=Adam(lr=1e-5, beta_1=0.1),
                                   metrics=['accuracy'])

        self.generator = self.buildGenerator()
        self.combined = self.buildCombined()
        self.combined.compile(loss='binary_crossentropy',
                              optimizer=Adam(lr=8e-4, beta_1=0.5))


    def buildGenerator(self):
        model = Sequential()
        model.add(Dense(input_dim=(self.noiseDim + 10), units=1024)) # z=100, y=10
        model.add(BatchNormalization())
        model.add(LeakyReLU(0.2))
        model.add(Dense(128*7*7))
        model.add(BatchNormalization())
        model.add(LeakyReLU(0.2))
        model.add(Reshape((7, 7, 128), input_shape=(128*7*7,)))
        model.add(UpSampling2D())
        model.add(Conv2D(64, 5, padding='same'))
        model.add(BatchNormalization())
        model.add(LeakyReLU(0.2))
        model.add(UpSampling2D())
        model.add(Conv2D(1, 5, padding='same'))
        model.add(Activation('tanh'))
        return model

    def buildDiscriminator(self):
        model = Sequential()
        model.add(Conv2D(64, kernel_size=5, strides=2, input_shape=self.imgShape, padding="same"))
        model.add(LeakyReLU(0.2))
        model.add(Dropout(0.25))
        model.add(Conv2D(64, kernel_size=5, strides=2, padding="same"))
        model.add(LeakyReLU(0.2))
        model.add(Dropout(0.25))
        model.add(Flatten())
        model.add(LeakyReLU(0.2))
        model.add(Dropout(0.5))
        model.add(Dense(1, activation='sigmoid'))
        img = Input(shape=self.imgShape)
        validity = model(img)
        return Model(img, validity)

    def buildCombined(self):
        noise = Input(shape=(self.noiseDim + 10,))
        label = Input(shape=(28, 28, 10,))

        fakeImage = self.generator(noise)
        fakeImage = Concatenate(axis=3)([fakeImage, label])

        self.discriminator.trainable = False
        valid = self.discriminator(fakeImage)

        model = Model(input=[noise, label], output=valid)
        return model

学習部分

    def labelToImage(self, label):
        channel = np.zeros((28, 28, 10))
        channel[:, :, label] += 1
        return channel

    def conbiOnehot(self, noise, label):
        oneHot = np.eye(10)[label]
        return np.concatenate((noise, oneHot), axis=1)


    def train(self, epochs, batchSize):
        (X, Y), (_, _) = fashion_mnist.load_data()
        X = (X.astype(np.float32) - 127.5) / 127.5
        X = X.reshape([-1, 28, 28, 1])

        discriminator = self.buildDiscriminator()
        d_opt = Adam(lr=1e-5, beta_1=0.1)
        discriminator.compile(loss='binary_crossentropy', optimizer=d_opt, metrics=['accuracy'])

        g_opt = Adam(lr=.8e-4, beta_1=0.5)
        self.combined.compile(loss='binary_crossentropy', optimizer=g_opt)

        halfBatch = int(batchSize / 2)
        for epoch in range(epochs + 1):
            noise = np.random.normal(-1, 1, (halfBatch, self.noiseDim))
            labelNoise = np.random.randint(0, 10, halfBatch)
            noise = self.conbiOnehot(noise, labelNoise)
            fakeImage = self.generator.predict(noise)
            labelImage = np.array([self.labelToImage(i) for i in labelNoise])
            fakeImage = np.concatenate((fakeImage, labelImage), axis=3)

            index = np.random.randint(0, X.shape[0], halfBatch)
            realImage, realLabel = X[index], Y[index]
            labelImage = np.array([self.labelToImage(i) for i in realLabel])
            realImage = np.concatenate((realImage, labelImage), axis=3)

            x = np.concatenate((realImage, fakeImage))
            y = np.array([1]*halfBatch + [0]*halfBatch)
            disLoss = self.discriminator.train_on_batch(x, y)

            noise = np.random.normal(-1, 1, (batchSize, self.noiseDim))
            randomLabel = np.random.randint(0, 10, batchSize)
            noise = self.conbiOnehot(noise, randomLabel)
            randomImage = np.array([self.labelToImage(i) for i in randomLabel])
            genLoss = self.combined.train_on_batch([noise, randomImage], np.ones(batchSize))

            print("epoch:%d,  [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, disLoss[0], 100 * disLoss[1], genLoss))
            if epoch % 1000 == 0:
                self.save_imgs(epoch)

        self.generator.save('mnist_generator.h5')

    def save_imgs(self, epoch):
        r, c = 10, 10
        label = [i for i in range(10)] * 10

        noise = np.random.uniform(-1, 1, (r * c, self.noiseDim))
        noehot = np.array([np.eye(10)[i] for i in label])
        noise = np.concatenate((noise, noehot), axis=1)
        genImgs = self.generator.predict(noise)

        genImgs = 0.5 * genImgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(genImgs[cnt, :, :, 0], cmap='gray')
                axs[i, j].axis('off')
                cnt += 1
        fig.savefig("images/cGAN/mnist_%d.png" % epoch)
        plt.close()


if __name__ == '__main__':
    gan = GAN()
    gan.train(epochs=10000, batchSize=64)

で、出てきたのがこちら。
mnist_0To10000.gif
時折崩れてはいますが、狙った数字を出力することができています。

cifar10でもやってみた

mnistでも出来たしcifarでもやってみようかな、ってことでやってみた。
cifar10で学習するにあたって、noiseサイズを1024に拡大、generator、discriminatorを少し整形した。

    def buildGenerator(self):
        model = Sequential()
        model.add(Dense(input_dim=(self.noiseDim + 10), units=2048))  # z=100, y=10
        model.add(BatchNormalization())
        model.add(LeakyReLU(0.2))
        model.add(Reshape((4, 4, 128), input_shape=(128 * 4 * 4,)))
        model.add(UpSampling2D((2, 2)))
        model.add(Conv2D(64, (5, 5), padding='same'))
        model.add(BatchNormalization())
        model.add(LeakyReLU(0.2))
        model.add(UpSampling2D((2, 2)))
        model.add(Dropout(0.25))
        model.add(Conv2D(32, (5, 5), padding='same'))
        model.add(BatchNormalization())
        model.add(LeakyReLU(0.2))
        model.add(UpSampling2D((2, 2)))
        model.add(Conv2D(3, (5, 5), padding='same'))
        model.add(Activation('tanh'))
        return model

    def buildDiscriminator(self):
        model = Sequential()
        model.add(Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=(self.imageShape)))
        model.add(LeakyReLU(0.2))
        model.add(Conv2D(128, (5, 5), strides=(2, 2)))
        model.add(LeakyReLU(0.2))
        model.add(Flatten())
        model.add(Dense(256))
        model.add(LeakyReLU(0.2))
        model.add(Dropout(0.5))
        model.add(Dense(1))
        model.add(Activation('sigmoid'))
        return model

イテレーションを100000回しましたが手書き数字のようにうまく形にならず、なんとなくその形に見える程度になった。
下のgifは10000イテレーションごとの結果で左から、飛行機、車、鳥、猫、鹿、犬、カエル、馬、船、トラックである。
cifar_0To10000.gif

2
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?