35
24

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

KerasAdvent Calendar 2017

Day 22

Keras による adversarial networks / adversarial training の実装

Last updated at Posted at 2017-12-24

みなさま,おはようございます,@_akisato でございます.

Keras Advent Calendar 2017 の22日目の記事として書いております.もうクリスマス当日になってしまった気がしますが...

本日は,[generative adversarial networks (GAN)][Goodfellow2014] や [virtual adversarial training (VAT)][Miyato2016] などの登場で非常に注目を集めている 敵対的学習 (adversarial training) を Keras でどのように実装するか,についてのガイドを紹介したいと思います.
[Goodfellow2014]:https://arxiv.org/abs/1406.2661
[Miyato2016]:https://arxiv.org/abs/1507.00677

この記事で紹介した実装については,github にアップしてあります.

敵対的学習とは

敵対的学習 (adversarial training) とは,学習したい機械学習モデルと,その学習を妨げる「何か」とを用意し,両者を競わせながら学習を進める枠組を指します.

この「何か」は,例えば,学習モデルに誤った予測出力をさせようとする入力であったり,学習モデルから得られる予測出力を真の正解と見分けようとする識別器であったりします.

最も有名なものは,GoodfellowらがNIPS2014で発表した [generative adversarial networks (GAN)][Goodfellow2014] でしょう.GANは,一様分布や正規分布などから生成されたランダムベクトルから何らかの画像を生成する generator と,入力された画像が generator が生成したものか本物の画像かを見分ける discriminatorを競わせながら学習することで,本物の画像に近い画像を生成する generator を学習する方法です.

GAN については,すでに数多くの解説記事が公開されていますので,詳細はそれらをご参照下さい.

敵対的学習としてもう一つ有名なのは,これまた Goodfellow らによる論文である [敵対的サンプル (adversarial examples) 生成][Goodfellow2015] です.敵対的サンプルとは,学習サンプルにごく少量の,しかし意図的なノイズを加えることで,元々の学習サンプルとは大きく異なる予測を出力してしまうサンプルを指します.[Miyanoらの virtual adversarial training (VAT)][Miyato2016] も,この敵対的サンプル生成を行う方法の一つです.
[Goodfellow2015]:https://arxiv.org/abs/1412.6572
[Miyato2016]:https://arxiv.org/abs/1507.00677

敵対的サンプル生成については,先日公開した関連記事である,敵対的サンプル生成ライブラリ cleverhans ことはじめ をご参照下さい.

敵対的学習には2種類ある

前節に記載した例の通り,敵対的学習には,大きく分けて2つのカテゴリが存在します1.前者の代表がGANであり,後者の代表がVATです.

  • 敵対的ネットワーク (adversarial network configurations)
  • 敵対的サンプル生成 (adversarial sample generation)

それぞれのカテゴリにおける実装の「肝」が異なりますので,以降では,この2つのカテゴリを区別して考え,それぞれについて keras 実装のガイドを紹介していきます.敵対的サンプル生成については,先日公開した 敵対的サンプル生成ライブラリ cleverhans ことはじめ にほとんど書いてありますので,ここでは敵対的ネットワークの実装に絞って説明したいと思います.

敵対的ネットワークの実装

まず,敵対的ネットワークのkeras実装について説明します.kerasでの敵対的なネットワーク実装において注意するべき点は,以下の2点です.これさえクリアできれば,他のモデルと大して変わりません.

  • 複数の異なる基準でのモデル更新を交互に繰り返し実行する必要がある.
  • モデルの一部を固定しつつ,残りの部分のモデルを更新する.

ここでは,敵対的ネットワークの例として,[deep convolutional generative adversarial networks (DCGAN)][Radford2016] の実装を考えます.DCGANで何ができるかやその背景については,ここここを参考にするとよいかと思います.
[Radford2016]:https://arxiv.org/abs/1511.06434

以下の説明を簡単にするために,generator を関数 g,discriminator を関数 d,generator を駆動するランダムベクトルを z,実在する画像を x とそれぞれ表記することにします.

ネットワーク構成

ネットワークそのものの実装は,通常のモデル実装とほとんど変わりません.MNIST (W28 x H28 x C1) を対象と想定した generator 及び discriminator は,例えばそれぞれ以下のように実装できます.

# generator
def generator(input_dim):
    __z = Input(shape=(input_dim,))
    reg = regularizers.l2(0.00001)
    rand_init = RandomNormal(stddev=0.02)
    # 5th fc
    __h = Dense(units=2*2*512, activation=None, kernel_initializer=rand_init,
                kernel_regularizer=reg, bias_regularizer=reg)(__z)
    __h = Reshape((2, 2, 512))(__h)
    __h = BatchNormalization(axis=-1)(__h)
    __h = Activation('relu')(__h)
    # 4th conv
    __h = Conv2DTranspose(filters=256, kernel_size=3, strides=2, padding='same',
                          activation=None, kernel_initializer=rand_init,
                          kernel_regularizer=reg, bias_regularizer=reg)(__h)
    __h = BatchNormalization(axis=-1)(__h)
    __h = Activation('relu')(__h)
    # 3rd conv
    __h = Conv2DTranspose(filters=128, kernel_size=4, strides=1, padding='valid',
                          activation=None, kernel_initializer=rand_init,
                          kernel_regularizer=reg, bias_regularizer=reg)(__h)
    __h = BatchNormalization(axis=-1)(__h)
    __h = Activation('relu')(__h)
    # 2nd conv
    __h = Conv2DTranspose(filters=64, kernel_size=3, strides=2, padding='same',
                          activation=None, kernel_initializer=rand_init,
                          kernel_regularizer=reg, bias_regularizer=reg)(__h)
    __h = BatchNormalization(axis=-1)(__h)
    __h = Activation('relu')(__h)
    # 1st conv
    __x = Conv2DTranspose(filters=1, kernel_size=3, strides=2, padding='same',
                          activation='tanh', kernel_initializer=rand_init,
                          kernel_regularizer=reg, bias_regularizer=reg)(__h)
    # return
    return Model(__z, __x, name='generator')

# discriminator
def discriminator():
    __x = Input(shape=(28, 28, 1))
    reg = regularizers.l2(0.00001)
    rand_init = RandomNormal(stddev=0.02)
    # 1st conv
    __h = Conv2D(filters=64,  kernel_size=3, strides=2, padding='same',
                 activation=None, kernel_initializer=rand_init,
                 kernel_regularizer=reg, bias_regularizer=reg)(__x)
    __h = PReLU(shared_axes=[1,2,3])(__h)
    # 2nd conv
    __h = Conv2D(filters=128, kernel_size=3, strides=2, padding='same',
                 activation=None, kernel_initializer=rand_init,
                 kernel_regularizer=reg, bias_regularizer=reg)(__h)
    __h = BatchNormalization(axis=-1)(__h)
    __h = PReLU(shared_axes=[1,2,3])(__h)
    # 3nd conv
    __h = Conv2D(filters=256, kernel_size=4, strides=1, padding='valid',
                 activation=None, kernel_initializer=rand_init,
                 kernel_regularizer=reg, bias_regularizer=reg)(__h)
    __h = BatchNormalization(axis=-1)(__h)
    __h = PReLU(shared_axes=[1,2,3])(__h)
    # 4th conv
    __h = Conv2D(filters=512, kernel_size=3, strides=2, padding='same',
                 activation=None, kernel_initializer=rand_init,
                 kernel_regularizer=reg, bias_regularizer=reg)(__h)
    __h = BatchNormalization(axis=-1)(__h)
    __h = PReLU(shared_axes=[1,2,3])(__h)
    # 5th fc
    __h = Flatten()(__h)
    __y = Dense(units=2, activation='softmax', kernel_initializer=RandomNormal(stddev=0.02),
                kernel_regularizer=reg, bias_regularizer=reg)(__h)
    # return
    return Model(__x, __y, name='discriminator')
    

ネットワークの学習

(DC)GANでは,generator と discriminator を別々に最適化します.もう少し詳しく書くと,以下のようになります.

  • Generator update
    • Discriminator $d(\cdot)$ を固定して,generator $g(\cdot)$ だけ更新する.
    • ランダムベクトル入力 $z$ から得られる discriminator 出力 $d(g(z))$ ができるだけ1に近くなるように更新する,
  • Discriminator update
    • Generator $g(\cdot)$ を固定して,discriminator $d(\cdot)$ だけ更新する.
    • ランダムベクトル入力 $z$ から得られる discriminator 出力 $d(g(z))$ ができるだけ0に近くなるように更新する.
    • 実画像入力 $x$ から得られる discriminator 出力 $d(x)$ ができるだけ1に近くなることも同時に求める.

複数の異なる基準でのモデル更新

まずは,「××を固定して,○○だけを更新」を棚上げにして,「××から得られる discriminator 出力ができるだけ○○に近くなるように更新」の部分だけ実装していきます.具体的には,以下の3項目を実装します.

  • $d(g(z))$ ができるだけ $1$ に近くなるように $g(\cdot)$ を更新,
  • $d(g(z))$ ができるだけ $0$ に近くなるように $d(\cdot)$ を更新.
  • $d(x)$ ができるだけ $1$ に近くなるように $d(\cdot)$ を更新.

いずれも2値分類の問題になっているので,相互関係を無視すれば,実装は容易です.

import keras
from keras.losses import categorical_crossentropy
from keras.metrics import categorical_accuracy
from keras.models import Model
from keras.layers import Input
import numpy as np

def GAN_part(input_dim):
    __z  = Input(shape=(input_dim,))
    __xs = generator(input_dim)(__z)
    __ys = discriminator()(__xs)
    return Model(__z, __ys)

# ...

gen_train_stage = GAN_part(100)
gen_train_optimizer = Adam(lr=0.0002, beta_1=0.5)
gen_train_stage.compile(optimizer=gen_train_optimizer, loss=categorical_crossentropy,
                        metrics=[categorical_accuracy]) 
# ...

Z_train_batch = np.asarray(np.random.uniform(-1.0, 1.0, size=(batch_size, input_dim)), dtype=np.float32)
Y_true_batch = np.ones(batch_size)
gen_loss_now, gen_acc_now = gen_train_stage.train_on_batch(Z_train_batch, Y_true_batch)

しかし,先ほどの3つの項目のうち,第1項目と第2項目は generator を利用しており,すべての項目において discriminator を利用しています.そこで,3つの項目をすべてを同時に含むように,実装を書き直すことにします.

def GAN_almost_there(input_dim):
    # inputs
    __z  = Input(shape=(input_dim,))    # input random vector
    __xt = Input(shape=(32, 32, 3))    # real image
    # generator
    gen = generator(input_dim)
    __xs = gen(__z)
    # discriminator
    dis = discriminator()
    __yt = dis(__xt)
    __ys = dis(__xs)
    # generator training stage
    gen_train_stage = Model(__z, __ys)
    # discriminator training stage
    dis_train_stage = Model([__z, __xt], [__ys, __yt])
    # return
    return gen_train_stage, dis_train_stage

# ...

gen_train_stage, dis_train_stage = GAN_almost_there(100)
gen_train_optimizer = Adam()
gen_train_stage.compile(optimizer=gen_train_optimizer, loss=categorical_crossentropy,
                        metrics=[categorical_accuracy]) 
dis_train_optimizer = Adam()
dis_train_stage.compile(optimizer=gen_train_optimizer, loss=categorical_crossentropy,
                        metrics=[categorical_accuracy]) 
# ...

X_train_batch = X[batch_begin:batch_end]
Z_train_batch = np.asarray(np.random.uniform(-1.0, 1.0, size=(batch_size, input_dim)), dtype=np.float32)
Y_true_batch = np.ones(batch_size)
Y_fake_batch = np.zeros(batch_size)
gen_loss, gen_acc = gen_train_stage.train_on_batch(Z_train_batch, Y_true_batch)
dis_loss, dis_loss_fake, dis_loss_true, dis_acc_fake, dis_acc_true = ¥
    dis_train_stage.train_on_batch([Z_train_batch, X_train_batch], [Y_fake_batch, Y_true_batch])

ここで大事なことは,以下の2点です.

  • generator と discriminator のインスタンスを先に作る.
    • もし,__yt = discriminator()(__xt)__ys = discriminator()(__xs) としてしまうと,2つの別の discriminatorができてしまい,目的のモデルが作れません.
  • generator 更新と discriminator 更新とで別の Model を作る.
    • Generator と discriminator で同じようなことをやっているように見えます(し,現時点では実際にその通りです)が,これは後で重要になります.

これで,「××ができるだけ△に近くなるように○○を更新」の部分を実装できました,

モデルの一部を固定して残りを更新

しかし,先ほど実装した GAN_almost_there では,「××を固定して,○○を更新」の部分が実現されていません.モデル更新に関して何らの指定をしていないので,与えた教師データに対して generator も discriminator も両方同時に更新してしまいます.

この問題を解決するために,モデルのパラメータ更新の可否を指定する trainable フラグを操作します.keras.layers や keras.models に含まれる各種の層やモデルは,trainable というフラグを持ちます.これを model.trainable = False のように設定すると,model のモデルパラメータが学習中に更新されなくなります.

dis = discrminator()
dis.trainable = False    # training disabled for this layer
dis.trainable = True     # training enabled for this layer

しかし,この trainable の設定においては,以下の2点に注意する必要があります.(参考: https://qiita.com/t-ae/items/236457c29ba85a7579d5

  • trainable の設定は,当該 layer もしくは model のみで有効で,その構成要素である layer や model には適用されない.
  • trainable の設定は,model を compile するまで効力を発揮しない.

まず1点目ですが,例えば 直前の例で dis.trainable = False とした場合,discriminator そのものは trainable = False になるのですが,その中に含まれる 例えば Conv2D などの各層の trainable までは変更されません.そのため,あるモデル全体を trainable = False にしたい場合には,以下のようにして,再帰的に traiable の設定をしなければなりません.

def set_trainable(model, trainable=False):
    model.trainable = trainable
    try:
        layers = model.layers
    except:
        return
    for layer in layers:
        set_trainable(layer, trainable)

dis = discriminator()
set_trainable(dis, trainable=False)    # training disabled for this model
set_trainable(dis, trainable=True)     # training enabled for this model

2点目も含めて考慮すると,GANのモデル実装は以下のようになります.まずモデル構造と入出力関係を定義し,その後に trainable 設定を行い,最後に compile する,という手順になります.

def GAN(input_dim):
    # inputs
    __z  = Input(shape=(input_dim,))
    __xt = Input(shape=(32, 32, 3))
    # generator
    gen = generator(input_dim)
    __xs = gen(__z)
    # discriminator
    dis = discriminator()
    __yt = dis(__xt)
    __ys = dis(__xs)
    # generator training stage
    set_trainable(gen, True)
    set_trainable(dis, False)
    gen_train_stage = Model(__z, __ys)
    gen_train_optimizer = Adam()
    gen_train_stage.compile(optimizer=gen_train_optimizer, loss=categorical_crossentropy, metrics=[categorical_accuracy])
    # discriminator training stage
    set_trainable(gen, False)
    set_trainable(dis, True)
    dis_train_stage = Model([__z, __xt], [__ys, __yt], name='dis_train_stage')
    dis_train_optimizer = Adam()
    dis_train_stage.compile(optimizer=dis_train_optimizer, loss=categorical_crossentropy, metrics=[categorical_accuracy])
    # return
    return gen_train_stage, dis_train_stage

これで,所望の動作をする GAN を実装できました.全体の実装は ここ に置きました.この実装を実行すると,だいたいこんな感じの画像が得られます.MNISTは簡単なデータなので,あっという間に読める文字が生成されます.

1エポック終了後
generated_images_00000.png

5エポック終了後
generated_images_00005.png

10エポック終了後
generated_images_00010.png

100エポック終了後
generated_images_00100.png

  1. 命名は以降の説明のための便宜上のもので,何らかの文献に基づくものではありません,ご了承下さい.

35
24
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
35
24

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?