LoginSignup
3
5

More than 3 years have passed since last update.

【ディープラーニング】DCGANを使って新種のポケモンを生成する

Posted at

はじめに

既に多くの人が取り組んでいますポケモンの自動生成に取り組んだ際のまとめです。
DCGANを使い実装しています。

GANのおさらい

GAN(Generative Adversarial Networks)はデータを生成するタイプのニューラルネットを学習するための手法です。discriminatorとgeneratorの2つを学習させて、お互いに相手の結果を学び合うようにします。

20190302143222.png

(出展)  https://medium.freecodecamp.org/an-intuitive-introduction-to-generative-adversarial-networks-gans-7a2264a81394

GANの手法はしばし贋作者と鑑定家の関係にたとえられます。discriminatorはデータが本物か偽物かを判定する鑑定家として動作する一方、generatorは本物と見分けのつかない偽物を生成します。

GANの学習が進んでいくと、贋作者は鑑定家がどのような特徴をみて判断しているかを学んでいき、一方で鑑定家も偽物を正しく見分けられるよう学んでいきます。それぞれのネットワークを競わせて行き、贋作者が本物と区別のつかないような偽物を生成できるようになれば完成です。

ポケモンの生成

学習の流れ

例えばピカチュウの偽物生成と偽物判別を例にとってみます。

2020y02m23d_221014072.png

まずノイズをgeneratorに入れ、偽物の画像を生成します。ピカチュウは可愛らしいポケモンですが、generatorの精度が十分ではないため、今回はムキムキマッチョのピカチュウが生成されてしまったとします。

その後、本物データからピカチュウの本物画像抜出し、「本物」だというラベルと一緒にdiscriminatorに入力します。一方でマッチョなピカチュウも「偽物」だというラベルと一緒にdiscriminatorに入力します。
このことにより、discriminatorはマッチョなピカチュウは偽物で、可愛らしいピカチュウは本物であると学習をします。

学習ののち、この画像の差異(loss)をgenerator側にフィードバック(backprop)します。この場合には可愛い方が本物のピカチュウで、マッチョなのは偽物であるとフィードバックをします。

2020y02m23d_221053334.png

フィードバックにより、generatorはマッチョなピカチュウは偽物だと見破られてしまうと学習し、可愛らしいピカチュウを生成するようになります。しかし可愛らしいピカチュウを生成できたものの、まだ本物との差異があります。

これを繰り返すことで、より本物に近いポケモンらしい画像を生成していけるようになります。

データセット

kaggleでポケモンのデータセットが公開されているため、そちらを利用します。

結果

  • epoch 0
    pokemon00000000.png

  • epoch 1000
    pokemon00001000.png

  • epoch 5000
    pokemon00005000.png

  • epoch 10000
    pokemon00010000.png

  • epoch 18000
    pokemon00018000.png

危険そうなポケモンがしばし見受けられます。。。
学習を繰り返すことで、輪郭を形づくることができました。しかしながら、可愛らしさがない、むしろグロテスクさのあるフォルムとなりました。またピクセル単位で画像らしく生成できているものはなく、新種のポケモンを完全自動生成できるのはまだまだ遠い未来になりそうです。

(参考)プログラム

以下を参考に実装させていただきました。

import cv2
import os
import numpy as np
import matplotlib.pyplot as plt
from keras.models import Sequential, Model
from keras.layers import Conv2D, Dense, Dropout, Activation, GlobalAveragePooling2D, Input, BatchNormalization, Reshape, UpSampling2D
from keras.optimizers import Adam
from keras.layers.advanced_activations import LeakyReLU
from datetime import datetime

class DCGAN():
    def __init__(self, img_size, img_channels, ans_path , save_path, save_name):

        self.img_rows, self.img_cols = img_size
        self.channels = img_channels
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.save_path = save_path
        self.save_name = save_name
        os.makedirs(save_path,exist_ok=True)

        self.ans_path = ans_path
        self.names = os.listdir(ans_path)
        self.names.sort()
        self.X_train = []

        for name in self.names:
            img = cv2.imread(ans_path + name)
            img = cv2.resize(img , (64, 64))
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            self.X_train.append(img)

        # normalization
        self.X_train = (np.array(self.X_train) / 127.5) - 1.0

        # hparams
        self.z_dim = 100
        loss="binary_crossentropy"
        optimizer = Adam(0.0002, 0.5)

        # discriminator
        self.discriminator = self.Discriminator()
        self.discriminator.compile(loss=loss,optimizer=optimizer,metrics=['accuracy'])

        # Generator
        self.generator = self.Generator()
        self.combined = self.Combined()
        self.combined.compile(loss=loss, optimizer=optimizer)


    def Generator(self):
        noise_shape = (self.z_dim,)
        model = Sequential()
        model.add(Dense(128 * 16 * 16, activation="relu", input_shape=noise_shape))
        model.add(Reshape((16, 16, 128)))
        model.add(BatchNormalization(momentum=0.8))
        model.add(UpSampling2D())
        model.add(Conv2D(128, kernel_size=3, padding="same"))
        model.add(Activation("relu"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(UpSampling2D())
        model.add(Conv2D(64, kernel_size=3, padding="same"))
        model.add(Activation("relu"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(3, kernel_size=3, padding="same"))
        model.add(Activation("tanh"))
        model.summary()

        return model

    def Discriminator(self):

        img_shape = (self.img_rows, self.img_cols, self.channels)
        model = Sequential()
        model.add(Conv2D(32, kernel_size=3, strides=2, input_shape=img_shape, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(GlobalAveragePooling2D())
        model.add(Dense(2, activation='sigmoid'))
        model.summary()

        return model

    def Combined(self):
        self.discriminator.trainable = False
        model = Sequential([self.generator, self.discriminator])
        return model

    def save_imgs(self, epoch):
        r, c = 7, 10
        noise = np.random.normal(0, 1, (r * c, self.z_dim))
        gen_imgs = self.generator.predict(noise)

        # renorm
        gen_imgs = 0.5 * gen_imgs + 0.5
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[cnt])
                axs[i, j].axis('off')
                cnt += 1
        fig.savefig(self.save_path + self.save_name + str(epoch).zfill(8) +".png" )


    def train(self, epochs, batch_size, itv):
        for epoch in range(epochs):

            # discriminator
            imgs_true = np.array(self.X_train)[np.random.randint(0, len(self.X_train)-1 , batch_size//2)]
            noise = np.random.normal(0, 1, (batch_size//2, self.z_dim))

            imgs_fake = self.generator.predict(noise)
            batch_true = np.zeros([batch_size, 2])
            batch_true[:,0] = 1
            batch_false = np.zeros([batch_size, 2])
            batch_false[:,1] = 1

            d_loss_real = self.discriminator.train_on_batch(imgs_true, batch_true[:batch_size//2])
            d_loss_fake = self.discriminator.train_on_batch(imgs_fake, batch_false[:batch_size//2])


            # generator
            noise = np.random.normal(0, 1, (batch_size, self.z_dim))
            valid_y = np.array([1] * batch_size)
            g_loss = self.combined.train_on_batch(noise, batch_true)

            if epoch % 10 ==0:
                print("epoch=%d:  (discriminator) d_loss_real: %f, d_loss_fake: %f (generater)leaning loss: %f" \
                    % (epoch, d_loss_real[0], d_loss_fake[0], g_loss))

            if epoch % itv == 0:
                self.save_imgs(epoch)

if __name__ == '__main__':
    jobid=datetime.now().strftime("%Y/%m/%d%H%M%S").replace("/","")

    # param 
    ans_path="./data/pokemon/images/images/"
    save_path="./results/"
    save_name="pokemon"
    image_width=64
    image_height=64
    img_channels=3
    epochs=1000000
    batch_size=32
    iter_out=1000

    save_path=save_path+str(jobid)+"/"
    img_size=(image_width,image_height)

    # exe
    gan = DCGAN(img_size,img_channels,ans_path, save_path, save_name)
    gan.Generator()
    gan.Discriminator()
    gan.Combined()
    gan.train(epochs ,batch_size,iter_out)

3
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
5