LoginSignup
8
13

More than 3 years have passed since last update.

kerasによる敵対的生成ネットワーク(GAN)の実装例【初学者向け】

Last updated at Posted at 2020-10-16

この記事でやったこと

- GANによるminstの画像生成
- kerasを使った実装方法を紹介

はじめに

敵対的生成ネットワーク、つまりGAN。なんだか凄い流行ってるって事はよく聞きますが、実際に自分で実装しようとなるとなかなか敷居高いですよね。

自分もこれまで重要そうな技術だなあ、とっておきながら外から見るだけで放置していました。意外とそういう人って結構多いんじゃないですかね。

そんなGANについて今回はmnist のデータを用いて実装した例を紹介します。
データやコードは「pythonによる教師なし学習」を参考にさせて頂いています。

参考にした書籍ではオブジェクト指向を使って書かれていたので、 ちょっとレベルが高かったですがかなり勉強になりました。

同じく、初学者の参考になれば幸いです。

最初に得られた結果を示しておきます。見た目にインパクトあるので。

本物
image.png

生成
mnist_14000.png

生成画像もなかには気持ち悪いくらい似ている気がしますね...
もっと長く学習ささせればもっといいものができるかもしれないです。

敵対的生成ネットワーク(GAN)ってなに?

ここでは簡単に概要を述べます。詳細はコチラの記事を参考ください。
GAN:敵対的生成ネットワークとは何か ~「教師なし学習」による画像生成
https://www.imagazine.co.jp/gan%EF%BC%9A%E6%95%B5%E5%AF%BE%E7%9A%84%E7%94%9F%E6%88%90%E3%83%8D%E3%83%83%E3%83%88%E3%83%AF%E3%83%BC%E3%82%AF%E3%81%A8%E3%81%AF%E4%BD%95%E3%81%8B%E3%80%80%EF%BD%9E%E3%80%8C%E6%95%99%E5%B8%AB/

GANを用いることで、データセットを学習して、まるで同じデータセットと同じようなデータを作成することができます。
参考記事の例では実際には存在しないベッドルームの写真をGAN を用いて生成しています。やーまったく見分けがつかないですね機械学習恐ろしい。

今回の記事ではmnistを用いるので、 手書き文字を生成していきます。 このどうやってこの手書き文字を生成しているのでしょうか。

GANでは、データを生成するモデルとデータを識別するモデルの二つがあります。
データを生成するモデルでは、 手書き文字っぽいデータを作成していきます。 そしてその作成されたデータを識別モデルで、 偽物なのか本物なのかを判断していきます。そしてその結果をもとに生成モデルを学習させて、次はより本物に近い画像を作成していきます。

簡単に言うとこれだけのモデルなんですよね。 ただ疑問に残るのは、 どうやってデータを学習させるか、どうやってデータを学習させるのか?だと思います。

データの学習

今回のモデルではデータの学習は下記のように行います。

- 生成モデルでノイズ(100*1*1)から画像(1*28*28)を生成する
- 「実際の画像」と「生成モデルで作成された画像」で識別モデルを学習させる
- 新たに生成モデルから画像を生成する。生成された画像が識別モデルで「実際の画像」と分類されるように、生成モデルと識別モデルを学習させる。

このモデルを実際に実装していきます。

ライブラリのインポート

参考図書のまんまですが、google colabでも使えるようにすこし改良しています。


'''Main'''
import numpy as np
import pandas as pd
import os, time, re
import pickle, gzip, datetime

'''Data Viz'''
import matplotlib.pyplot as plt
import seaborn as sns
color = sns.color_palette()
import matplotlib as mpl
from mpl_toolkits.axes_grid1 import Grid

%matplotlib inline

'''Data Prep and Model Evaluation'''
from sklearn import preprocessing as pp
from sklearn.model_selection import train_test_split 
from sklearn.model_selection import StratifiedKFold 
from sklearn.metrics import log_loss, accuracy_score
from sklearn.metrics import precision_recall_curve, average_precision_score
from sklearn.metrics import roc_curve, auc, roc_auc_score, mean_squared_error
from keras.utils import to_categorical

'''Algos'''
import lightgbm as lgb

'''TensorFlow and Keras'''
import tensorflow as tf
import keras
from keras import backend as K
from keras.models import Sequential, Model
from keras.layers import Activation, Dense, Dropout, Flatten, Conv2D, MaxPool2D
from keras.layers import LeakyReLU, Reshape, UpSampling2D, Conv2DTranspose
from keras.layers import BatchNormalization, Input, Lambda
from keras.layers import Embedding, Flatten, dot
from keras import regularizers
from keras.losses import mse, binary_crossentropy
from IPython.display import SVG
from keras.utils.vis_utils import model_to_dot
from keras.optimizers import Adam, RMSprop

from keras.datasets import mnist

sns.set("talk")

データ読み込み

データの読み込みです。minstのデータを使用します。colaboratoryでの使用を想定しています。
x_trainしか使用しないので、reshpaeと0~1の値の正規化はx_trainにしかしていません。


# 学習データとテストデータに分割したデータ
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape((60000, 28, 28, 1))
# ピクセルの値を 0~1 の間に正規化
x_train= x_train / 255.0

DCGANのクラス設計

超重要なDCGANのコードです。生成モデル、識別モデルをまとめたクラスで定義しています。
それぞれの関数の働きを簡単に記すと

generator

  • 100*1*1のベクトルを28*28*1の画像に変換するニューラルネットワーク
  • これを学習させていくことでそれっぽい画像が生成されるようになる

discriminator

  • 28*28*1の画像が本物か、偽物かを識別するニューラルネットワーク

discriminator_model

  • 識別用のニューラルネットワークをコンパイルしてモデル化

adversarial?model

  • generaorとdiscriminatorをつなげて作成されたモデル
  • このモデルで生成ネットワークを学習させる
#DCGANのクラス
class DCGAN(object):
  #初期化
    def __init__(self, img_rows=28, img_cols=28, channel=1):

        self.img_rows = img_rows
        self.img_cols = img_cols
        self.channel = channel
        self.D = None   # discriminator
        self.G = None   # generator
        self.AM = None  # adversarial model
        self.DM = None  # discriminator model

    #生成ネットワーク
    #100*1*1の行列をデータセットの画像と同じ1*28*28にする
    def generator(self, depth=256, dim=7, dropout=0.3, momentum=0.8, \
                  window=5, input_dim=100, output_depth=1):
        if self.G:
            return self.G
        self.G = Sequential()

        #100*1*1 → 256*7*7
        self.G.add(Dense(dim*dim*depth, input_dim=input_dim))
        self.G.add(BatchNormalization(momentum=momentum))
        self.G.add(Activation('relu'))
        self.G.add(Reshape((dim, dim, depth)))
        self.G.add(Dropout(dropout))

        #256*7*7 → 128*14*14
        self.G.add(UpSampling2D())
        self.G.add(Conv2DTranspose(int(depth/2), window, padding='same'))
        self.G.add(BatchNormalization(momentum=momentum))
        self.G.add(Activation('relu'))

        #128*14*14 → 64*28*28
        self.G.add(UpSampling2D())
        self.G.add(Conv2DTranspose(int(depth/4), window, padding='same'))
        self.G.add(BatchNormalization(momentum=momentum))
        self.G.add(Activation('relu'))

        #64*28*28→32*28*28
        self.G.add(Conv2DTranspose(int(depth/8), window, padding='same'))
        self.G.add(BatchNormalization(momentum=momentum))
        self.G.add(Activation('relu'))

        #1*28*28
        self.G.add(Conv2DTranspose(output_depth, window, padding='same'))
        #各ピクセルを0~1の間の値にする
        self.G.add(Activation('sigmoid')) 
        self.G.summary()
        return self.G


    #識別ネットワーク
    #28*28*1の画像が本物かどうかを見分ける
    def discriminator(self, depth=64, dropout=0.3, alpha=0.3):
        if self.D:
            return self.D

        self.D = Sequential()
        input_shape = (self.img_rows, self.img_cols, self.channel)

      #28*28*1 → 14*14*64
        self.D.add(Conv2D(depth*1, 5, strides=2, input_shape=input_shape,padding='same'))
        self.D.add(LeakyReLU(alpha=alpha))
        self.D.add(Dropout(dropout))

      #14*14*64 → 7*7*128
        self.D.add(Conv2D(depth*2, 5, strides=2, padding='same'))
        self.D.add(LeakyReLU(alpha=alpha))
        self.D.add(Dropout(dropout))

      #7*7*128 → 4*4*256
        self.D.add(Conv2D(depth*4, 5, strides=2, padding='same'))
        self.D.add(LeakyReLU(alpha=alpha))
        self.D.add(Dropout(dropout))

        #4*4*512 → 4*4*512 ####ただしあっているか確認###
        self.D.add(Conv2D(depth*8, 5, strides=1, padding='same'))
        self.D.add(LeakyReLU(alpha=alpha))
        self.D.add(Dropout(dropout))

        #フラット化してsigmoidで分類
        self.D.add(Flatten())
        self.D.add(Dense(1))
        self.D.add(Activation('sigmoid'))

        self.D.summary()
        return self.D

    #識別モデル
    def discriminator_model(self):
        if self.DM:
            return self.DM
        optimizer = RMSprop(lr=0.0002, decay=6e-8)
        self.DM = Sequential()
        self.DM.add(self.discriminator())
        self.DM.compile(loss='binary_crossentropy', \
                        optimizer=optimizer, metrics=['accuracy'])
        return self.DM

    #生成モデル
    def adversarial_model(self):
        if self.AM:
            return self.AM
        optimizer = RMSprop(lr=0.0001, decay=3e-8)
        self.AM = Sequential()
        self.AM.add(self.generator())
        self.AM.add(self.discriminator())
        self.AM.compile(loss='binary_crossentropy', \
                        optimizer=optimizer, metrics=['accuracy'])
        return self.AM

mnist用DCGANのクラス設計

次にこれらの関数を用いてminstのデータを実際に訓練して画像を生成していきます。
train関数で画像の訓練を、plot_imagesで画像を保存していきます。

train関数では下記のような流れで実行されています。

- 訓練用のデータをノイズから生成
- 生成されたデータを識別モデルにかける。このときどのくらい上手に識別できたかをD_lossに保存。
- 生成データが本物っぽくなるようadversarial_modelで学習させる。このときにどのくらい騙せたかをA_lossに保存。

#MNISTのデータにDCGANを適用するクラス
class MNIST_DCGAN(object):
    #初期化
    def __init__(self, x_train):
        self.img_rows = 28
        self.img_cols = 28
        self.channel = 1

        self.x_train = x_train

        #DCGANの識別、敵対的生成モデルの定義
        self.DCGAN = DCGAN()
        self.discriminator =  self.DCGAN.discriminator_model()
        self.adversarial = self.DCGAN.adversarial_model()
        self.generator = self.DCGAN.generator()

    #訓練用の関数
    #train_on_batchは各batchごとに学習している。出力はlossとacc
    def train(self, train_steps=2000, batch_size=256, save_interval=0):
        noise_input = None

        if save_interval>0:
            noise_input = np.random.uniform(-1.0, 1.0, size=[16, 100])

        for i in range(train_steps):
            #訓練用のデータをbatch_sizeだけランダムに取り出す
            images_train = self.x_train[np.random.randint(0,self.x_train.shape[0], size=batch_size), :, :, :] 

            # 100*1*1のノイズをbatch sizeだけ生み出して偽画像とする
            noise = np.random.uniform(-1.0, 1.0, size=[batch_size, 100])

            #生成画像を学習させる
            images_fake = self.generator.predict(noise)
            x = np.concatenate((images_train, images_fake))
            #訓練データを1に、生成データを0にする
            y = np.ones([2*batch_size, 1])
            y[batch_size:, :] = 0

            #識別モデルを学習させる
            d_loss = self.discriminator.train_on_batch(x, y)

            y = np.ones([batch_size, 1])
            noise = np.random.uniform(-1.0, 1.0, size=[batch_size, 100])

            #生成&識別モデルを学習させる
            #生成モデルの学習はここでのみ行われる
            a_loss = self.adversarial.train_on_batch(noise, y)

            #訓練データと生成モデルのlossと精度
            #D lossは生成された画像と実際の画像のときのlossとacc
            #A lossはadversarialで生み出された画像を1としたときのlossとacc
            log_mesg = "%d: [D loss: %f, acc: %f]" % (i, d_loss[0], d_loss[1])
            log_mesg = "%s  [A loss: %f, acc: %f]" % (log_mesg, a_loss[0], a_loss[1])
            print(log_mesg)

            #save_intervalごとにデータを保存する
            if save_interval>0:
                if (i+1)%save_interval==0:
                    self.plot_images(save2file=True, \
                        samples=noise_input.shape[0],\
                        noise=noise_input, step=(i+1))

    #訓練結果をプロットする
    def plot_images(self, save2file=False, fake=True, samples=16, \
                    noise=None, step=0):
        current_path = os.getcwd()
        file = os.path.sep.join(["","data", 'images', 'chapter12', 'synthetic_mnist', ''])
        filename = 'mnist.png'
        if fake:
            if noise is None:
                noise = np.random.uniform(-1.0, 1.0, size=[samples, 100])
            else:
                filename = "mnist_%d.png" % step
            images = self.generator.predict(noise)
        else:
            i = np.random.randint(0, self.x_train.shape[0], samples)
            images = self.x_train[i, :, :, :]

        plt.figure(figsize=(10,10))
        for i in range(images.shape[0]):
            plt.subplot(4, 4, i+1)
            image = images[i, :, :, :]
            image = np.reshape(image, [self.img_rows, self.img_cols])
            plt.imshow(image, cmap='gray')
            plt.axis('off')
        plt.tight_layout()
        if save2file:
            plt.savefig(current_path+file+filename)
            plt.close('all')
        else:
            plt.show()

終わりに

GAN、すごいですね。なんか手書き文字っぽいものが生成されるのみると気持ち悪さすら感じます。

実際の現場だと異常検知などで使えるみたいですね。

ただ、参考図書のまとめのところに「GANを使う場合、多大なる労力がかかることを覚悟してほしい」という記載がありました。その詳しい理由については記述はありませんでしたが...

一体どれだけ苦労するんだ、GAN。

最後までお読みいただきありがとうございました。

8
13
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8
13