LoginSignup
33
28

More than 5 years have passed since last update.

敵対的サンプル生成ライブラリ cleverhans ことはじめ

Last updated at Posted at 2017-12-02

みなさま,おはようございます,@_akisato でございます.

Tensorflow Advent Calender 2017 及び Deep Learning フレームワークざっくり紹介 Advent Calendar 2017 の6日目の記事として書いております.

本日は,敵対的サンプル (adversarial examples) 生成のためのライブラリとして,Tensorflow 開発陣が開発した cleverhans を紹介します.

この記事で紹介した実装については,github にアップしてあります.

(12月7日に,大幅に追記修正を行いました.それ以前にご覧になった方は,是非再度ご確認を頂けますと幸いです.)

敵対的サンプル adversarial examples とは

敵対的サンプル (adversarial examples) とは,学習サンプルにごく少量の,しかし意図的なノイズを加えることで,元々の学習サンプルとは大きく異なる予測を出力してしまうサンプルを指します.

敵対的サンプルについては,すでに数多くの優れた記事が公開されておりますので,もしご存じない方はそれらをご参照下さい.

敵対的サンプルは,元々の学習サンプルと極めて類似しているにもかかわらず,モデルによる予測が大きく異なる,という,多くの問題において望ましくない性質を持っています.この敵対的サンプルが元の学習サンプルと同じような予測を行うように(ニューラルネットワークを含む)機械学習モデルを学習させることで,所定の機械学習モデルの性能を向上させようという方法論を,敵対的学習 (adversarial training)1 と呼び,GoodfellowらによるICLR2015論文以降,わずか2年で数多くの研究がなされています.詳細は下記の解説記事をご参照下さい.

敵対的サンプル生成の実装は面倒

引用した記事を確認していただけますとわかるかもしれませんが,敵対的サンプルを生成する実装は,それほど簡単ではありません.Kerasでその例を示したいと思います.

敵対的サンプル生成を考慮しない一般的な畳み込みニューラルネットワークを以下のように実装したとします.簡潔性を重視したため,モデルはかなり適当です.

mnist_plain.py
from __future__ import absolute_import, division, print_function
import numpy as np
from sklearn import metrics
import tensorflow as tf
from keras.datasets import mnist as dataset
from keras.models import Model, Input, Dense, Conv2D, Flatten, MaxPooling2D, Dropout
from keras.backend import tensorflow_backend
from keras.utils.np_utils import to_categorical
from keras.losses import categorical_crossentropy
from keras.optimizers import Adam

def CNN_Model(input_shape, output_dim):
    __x = Input(shape=input_shape)
    __h = Conv2D(filters=32, kernel_size=3, activation='relu')(__x)
    __h = Conv2D(filters=64, kernel_size=3, activation='relu')(__h)
    __h = MaxPooling2D(pool_size=(2, 2))(__h)
    __h = Dropout(rate=0.25)(__h)
    __h = Flatten()(__h)
    __h = Dense(units=128, activation='relu')(__h)
    __h = Dropout(rate=0.25)(__h)
    __y = Dense(units=output_dim, activation='softmax')(__h)
    return Model(__x, __y)

# main
if __name__ == "__main__":
    # GPU configulations
    config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
    session = tf.Session(config=config)
    tensorflow_backend.set_session(session)

    # random seeds
    np.random.seed(1)
    tf.set_random_seed(1)

    # parameters
    n_classes = 10
    n_channels = 1
    img_width = 28
    img_height = 28

    # load the dataset
    print('Loading the dataset...')
    (X_train, Y_train_int), (X_test, Y_test_int) = dataset.load_data()
    X_train = X_train[:, np.newaxis].transpose((0, 2, 3, 1)).astype('float32') / 255.0
    X_test = X_test[:, np.newaxis].transpose((0, 2, 3, 1)).astype('float32') / 255.0
    Y_train = to_categorical(Y_train_int, num_classes=n_classes)
    Y_test = to_categorical(Y_test_int, num_classes=n_classes)

    # training
    print('Train a NN model...')
    ## define
    input_shape = (img_width, img_height, n_channels)
    model = CNN_Model(input_shape, n_classes)
    optimizer = Adam()
    model.compile(optimizer=optimizer, loss=categorical_crossentropy, metrics=['accuracy'])
    ## train
    history = model.fit(X_train, Y_train, batch_size=100,
                        epochs=100, shuffle=True, initial_epoch=0)

    # test
    Y_train_pred = model.predict(X_train)
    Y_train_pred = Y_train_pred.argmax(axis=1)
    Y_test_pred  = model.predict(X_test)
    Y_test_pred = Y_test_pred.argmax(axis=1)
    print('Training score for a NN classifier: \t{0}'.format(
        metrics.accuracy_score(Y_train_int, Y_train_pred)))
    print('Test score for a NN classifier: \t{0}'.format(
        metrics.accuracy_score(Y_test_int, Y_test_pred)))
    print('Training classification report for a NN classifier\n{0}\n'.format(
        metrics.classification_report(Y_train_int, Y_train_pred)))
    print('Test classification report for a NN classifier\n{0}\n'.format(
        metrics.classification_report(Y_test_int, Y_test_pred)))

このモデル学習に,Goodfellowらが提案した fast gradient sign method (FGSM) による敵対的サンプル生成を組み込むと,以下のようになります.Keras による Virtual adversarial training の実装 をだいぶ参考にさせていただいています.

この実装のポイントは,以下の4点になります.

  • 学習対象のモデルをラップする Adversarial_Training モデルを作る.(下記実装13行目から)
  • ラッピングモデルの中で,ラップされたモデルの学習損失 main loss の勾配を手動で計算する(と,その勾配に基づいて敵対的サンプルを生成できる).(下記実装30-35行目)
  • ラッピングモデルの中で生成した敵対的サンプルから計算される学習損失 adversarial loss を main loss に足し込む.(下記実装22-27,39行目)
  • 敵対的サンプルを生成する前に,keras.backend.stop_gradient を用いて,勾配計算の連鎖を止める.これをしないと,モデル更新の際に main loss の2階微分相当が追加されてしまう.(下記実装33行目)
mnist_fgsm.py
from __future__ import absolute_import, division, print_function
import numpy as np
from sklearn import metrics
import tensorflow as tf
from keras import backend as K
from keras.datasets import mnist as dataset
from keras.models import Model, Input, Dense, Conv2D, Flatten, MaxPooling2D, Dropout
from keras.backend import tensorflow_backend
from keras.utils.np_utils import to_categorical
from keras.losses import categorical_crossentropy
from keras.optimizers import Adam

class Adversarial_Training(Model):
    _at_loss = None
    # set up
    def setup_at_loss(self, loss_func=categorical_crossentropy, eps=0.25/255.0, alpha=1.0):
        self._loss_func = loss_func
        self._alpha = alpha
        self._at_loss = self.at_loss(eps)
        return self
    # loss
    @property
    def losses(self):
        losses = super(self.__class__, self).losses
        if self._at_loss is not None:
            losses += [ self._alpha * self._at_loss ]
        return losses
    # VAT loss
    def at_loss(self, eps):
        # original loss
        loss_orig = self._loss_func(self.inputs[1], self.outputs[0])
        # gradients
        grads = K.stop_gradient(K.gradients(loss_orig, self.inputs[0]))[0]
        # perterbed samples
        new_inputs = self.inputs[0] + eps * K.sign(grads)
        # estimation for the perturbated samples
        outputs_perturb = self.call([new_inputs, self.inputs[1]])
        # computing losses
        loss = self._loss_func(self.inputs[1], outputs_perturb)
        return loss

def CNN_Model_with_AT(input_shape, output_dim, loss_func=categorical_crossentropy, eps=0.25, alpha=0.5):
    # core model
    __x = Input(shape=input_shape)
    __h = Conv2D(filters=32, kernel_size=3, activation='relu')(__x)
    __h = Conv2D(filters=64, kernel_size=3, activation='relu')(__h)
    __h = MaxPooling2D(pool_size=(2, 2))(__h)
    __h = Dropout(rate=0.25)(__h)
    __h = Flatten()(__h)
    __h = Dense(units=128, activation='relu')(__h)
    __h = Dropout(rate=0.25)(__h)
    __yp = Dense(units=output_dim, activation='softmax')(__h)
    # adversarial training
    __yt = Input(shape=(output_dim,))
    return Adversarial_Training([__x, __yt], __yp).setup_at_loss(loss_func=loss_func, eps=eps, alpha=alpha)

# main
if __name__ == "__main__":
    # ...    
    # training
    print('Train a NN model...')
    ## define
    input_shape = (img_width, img_height, n_channels)
    model = CNN_Model_with_AT(input_shape, n_classes, loss_func=categorical_crossentropy)    
    # ...

モデルが簡単なのであまり大変そうに見えないかもしれませんが,少しモデルが複雑になると,今回実装した Adversarial_Training クラスはすぐに動かなくなります.また,他の種類の敵対的サンプル,例えば Miyatoらの virtual adversarial training (VAT) を試したいときには,すべてを自力で実装しなければいけません.

敵対的サンプルに基づく敵対的学習の実装の難しさは,

  • 敵対的サンプルを得るために元の入力サンプルについてのモデルの勾配を計算しなければならず,
  • この敵対的サンプルによるモデルの損失を元の入力サンプルによるモデルの損失と同時に考慮しなければいけない

という2点にあります.この困難を,敵対的サンプルの種類ごとに倒していかなければならない苦労は,あまりしたくありません.

cleverhansを導入する

cleverhansの良いところ,悪いところ

cleverhans は,

  • 先に示した実装の面倒さを解消してくれることが期待される点,
  • 実装をほとんど変更することなく異なる種類の敵対的サンプルを利用できる点

において,自分で実装するよりも優れていると言えます.難点としては,ドキュメント があまり整備されていないため,examples, tutorial, ソースなどを見ながら解決しなければならない部分がどうしても残る,という点かと思います.

cleverhansのインストール

インストールはすぐにできます.github にガイドラインが書いてあるので,それに従って下さい.Tensorflowさえインストールできれば,特に難しいところはありません.

パッケージの依存関係からもうかがい知れるように,cleverhans はニューラルネットワークライブラリとして Tensorflow 及び Keras に対応しています.それ以外のライブラリは(たぶん)利用できません.

早速使ってみる

モデルの定義

cleverhans の導入により,モデル定義は元の実装 mnist_plain.py のようなシンプルな形に戻すことができます.Kerasモデルを cleverhans で利用するためには,Kerasモデルをラッパ cleverhans.utils_keras.KerasModelWrapper でラップします.Tensorflowモデルの場合には,このラッピングは必要なく,ほぼシームレスに利用できます.

from cleverhans.utils_keras import KerasModelWrapper
model_keras = CNN_Model(input_shape, n_classes)
model_cleverhans = KerasModelWrapper(model_keras)

敵対的サンプルの生成

敵対的サンプルをどのように計算するかについては,Tensorflow placeholder を用いて,その計算方法をモデルの外側で明示的に与える必要があります.敵対的サンプルの生成には,cleverhans.attacks.[何らかのメソッド].generate を利用します.ここでは,FastGradientMethod を例に取ります.

from cleverhans.attacks import FastGradientMethod
# placeholders
x = tf.placeholder(tf.float32, shape=(None, img_width, img_height, n_channels))
y = tf.placeholder(tf.float32, shape=(None, n_classes))
# method
fgsm = FastGradientMethod(model_cleverhans, sess=session)
fgsm_params = {'eps': 0.25, 'clip_min': 0.0, 'clip_max': 1.0}
# adversarial examples
x_adv = fgsm.generate(x, **fgsm_params)

FastGradientMethod の部分を別の方法に変更すれば,すぐに別の敵対的サンプル生成手法を試すことができます.どのような方法が実装されているかについては,ドキュメント を確認するのが良いかと思います.本記事執筆時点 (21:35 BST, December 2, 2017) では,

などの代表的な手法に加えて,

などの新しい手法も実装されています.

generateの第2引数には,敵対的サンプル生成の方法に応じたパラメータを python dict で指定します.どのようなパラメータを指定するべきかについては,ドキュメント を確認して下さい.

敵対的サンプルを生成するには,通常は,対象とするモデルが完全に見えていて,入力サンプルに対するモデルの勾配が計算できる必要がありますが,対象のモデルが完全にブラックボックスであっても敵対的サンプルを生成することを可能にする,Papernotらの black-box attack も実装されています.実装方法については,該当のチュートリアル を参考にするのが良いかと思います.

敵対的学習を利用せず,単純に敵対的サンプルが得られれば良い,という場合には,(モデルをしっかり学習した後に)placeholder で定義した部分を評価すればOKです.

X_train_adv = sess.run(x_adv, feed_dict={x: X_train})

敵対的学習を実行する

構成した敵対的サンプルを用いて敵対的学習を行う際には,cleverhans.utils_tf.model_train を使います.この model_train を使う際に,敵対的サンプルからモデル予測を得る方法を外から与えなければなりません.Kerasモデルを利用した場合でも,どうやらラッピング前の生のモデルでもどうやら動くようです.

from cleverhans.utils_tf import model_train
# predictions for clean examples
y_preds = model_keras(x)
# predictions for adversarial examples
y_preds_adv = model_keras(x_adv)
# train a model
rng = np.random.RandomState([2017, 12, 2])
train_params = {'nb_epochs': 100, 'batch_size': 100, 'learning_rate': 0.001}
model_train(session, x, y, y_preds, X_train, Y_train, verbose=True,
            predictions_adv=y_preds_adv), args=train_params, save=False, rng=rng)

model_train の引数は,先頭から,入力placeholder,目標placeholder,予測を獲る方法,入力サンプル,目標サンプル,となっています.また,predictions_adv に敵対的サンプルからモデル予測を獲る方法を,args にモデル学習のためのいくつかのパラメータを,それぞれ入れます.argsの中には,必ず以下の3点を入れる必要があります.

  • nb_epochs: 学習エポック数
  • batch_size: バッチサイズ
  • learning_rate: 学習率(Adamが強制的に使われます,他の選択肢はありません.)

学習プロセスを再現するために,乱数のシードを外から指定することができます.model_train の引数 rng として与えます.

rng = np.random.RandomState([2017, 12, 2])

学習モデルの保存

model_train の引数として save=True を指定すると,学習済モデルを保存することができます.このとき,どこに保存するかについての情報を model_train の引数 args の中に含める必要があります.

  • train_dir: モデルを保存するディレクトリ
  • filename: 保存するモデルのファイル名

検証データの利用

学習過程を検証するために検証用データを用いる場合には,どのような検証を行うかの手順を関数として定義して,その関数を model_train の引数 evaluate で与えます.evaluate で指定する関数には,cleverhans.utils.AccuracyReport オブジェクトを含めておき,このオブジェクトに検証データでの精度を保存するようにしておくと,便利かもしれません(必須ではありません).このオブジェクトの実体はソースで確認すると良いかと思います.

検証用の関数の中では,cleverhans.utils_tf.model_eval を利用すると便利かもしれませんが,これも必須ではありません.Tensorflow もしくは Keras ネイティブの関数でも問題なく動きます.model_eval の使い方は,model_train とほぼ同じです.

from cleverhans.utils_tf import model_eval
from cleverhans.utils import AccuracyReport

# object for accuracy report
report = AccuracyReport()
# function for validation
def evaluate():
    # evaluate a model with clean validation examples
    eval_params = {'batch_size': batch_size}
    accuracy = model_eval(sess, x, y, y_preds, X_test, Y_test, args=eval_params)
    print('Validation accuracy on clean examples: %0.4f' % accuracy)
    report.adv_train_clean_eval = accuracy
    # evaluate a model with adversarial validation examples
    accuracy = model_eval(sess, x, y, y_preds_adv, X_test, Y_test, args=eval_params)
    print('Test accuracy on adversarial examples: %0.4f' % accuracy)
    report.adv_train_adv_eval = accuracy

# train a model
model_train(sess, x, y, y_preds, X_train, Y_train,
            predictions_adv=y_preds_adv, evaluate=evaluate,
            args=train_params, save=False, rng=rng)

まとめると

これらを踏まえ,cleverhansを用いた敵対的学習の実装は以下のようになります.ここまでに記載したすべての要素が入っているわけではありませんが,その点はご容赦下さい.

mnist_fgsm_cleverhans.py
from __future__ import absolute_import, division, print_function
import numpy as np
from sklearn import metrics
import tensorflow as tf
from keras import backend as K
from keras.models import Model, Input, Dense, Conv2D, Flatten, MaxPooling2D, Dropout
from keras.backend import tensorflow_backend
from keras.utils.np_utils import to_categorical
from keras.losses import categorical_crossentropy
from keras.optimizers import Adam

from cleverhans.utils_tf import model_train, model_eval
from cleverhans.attacks import FastGradientMethod
from cleverhans.utils_keras import KerasModelWrapper
import logging
from cleverhans.utils import set_log_level

def CNN_Model(input_shape, output_dim):
    # core model
    __x = Input(shape=input_shape)
    __h = Conv2D(filters=32, kernel_size=3, activation='relu')(__x)
    __h = Conv2D(filters=64, kernel_size=3, activation='relu')(__h)
    __h = MaxPooling2D(pool_size=(2, 2))(__h)
#    __h = Dropout(rate=0.25)(__h)
    __h = Flatten()(__h)
    __h = Dense(units=128, activation='relu')(__h)
#    __h = Dropout(rate=0.25)(__h)
    __y = Dense(units=output_dim, activation='softmax')(__h)
    # return
    return Model(__x, __y)

# main
if __name__ == "__main__":
    # GPU configulations
    config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
    session = tf.Session(config=config)
    tensorflow_backend.set_session(session)

    # random seeds
    np.random.seed(1)
    tf.set_random_seed(1)

    # parameters
    n_classes = 10
    n_channels = 1
    img_width = 28
    img_height = 28

    # load the dataset
    print('Loading the dataset...')
    from keras.datasets import mnist as dataset
    (X_train, Y_train_int), (X_test, Y_test_int) = dataset.load_data()
    X_train = X_train[:, np.newaxis].transpose((0, 2, 3, 1)).astype('float32') / 255.0
    X_test = X_test[:, np.newaxis].transpose((0, 2, 3, 1)).astype('float32') / 255.0
    Y_train = to_categorical(Y_train_int, num_classes=n_classes)
    Y_test = to_categorical(Y_test_int, num_classes=n_classes)

    # training
    print('Train a NN model...')
    ## model definition
    input_shape = (img_width, img_height, n_channels)
    model_keras = CNN_Model(input_shape, n_classes)
    model_cleverhans = KerasModelWrapper(model_keras)
    train_params = {'nb_epochs': 100, 'batch_size': 100, 'learning_rate': 0.001}  # Adam is used by default
    # input Tensorflow placeholders
    x = tf.placeholder(tf.float32, shape=(None, img_width, img_height, n_channels))
    y = tf.placeholder(tf.float32, shape=(None, n_classes))
    ## adversarial examples
    fgsm = FastGradientMethod(model_cleverhans, sess=session)
    fgsm_params = {'eps': 0.25, 'clip_min': 0.0, 'clip_max': 1.0}
    x_adv = fgsm.generate(x, **fgsm_params)
    ## train
    rng = np.random.RandomState([2017, 12, 2])
    set_log_level(logging.DEBUG)
    model_train(session, x, y, model_keras(x), X_train, Y_train, verbose=True,
                predictions_adv=model_keras(x_adv), args=train_params, save=False, rng=rng)

    # test
    Y_train_pred  = model_keras.predict(X_train)
    Y_train_pred = Y_train_pred.argmax(axis=1)
    Y_test_pred  = model_keras.predict(X_test)
    Y_test_pred = Y_test_pred.argmax(axis=1)
    print('Training score for a NN classifier: \t{0}'.format(
        metrics.accuracy_score(Y_train_int, Y_train_pred)))
    print('Test score for a NN classifier: \t{0}'.format(
        metrics.accuracy_score(Y_test_int, Y_test_pred)))
    print('Training classification report for a NN classifier\n{0}\n'.format(
        metrics.classification_report(Y_train_int, Y_train_pred)))
    print('Test classification report for a NN classifier\n{0}\n'.format(
        metrics.classification_report(Y_test_int, Y_test_pred)))

ここでは,Kerasによる実装を示しましたが,Tensorflowでも利用できます.利用方法については チュートリアル を参考にすると良いかと思います.Kerasとそれほど違う印象はありません.

いくつか注意点

cleverhansは initial public release が2016年9月と非常に新しいライブラリで,やりたいことが何でもできるという状況には未だないように思います.

例えば,どうも現時点では,Kerasモデルにdropoutを組み込むことができないようです.上の mnist_fgsm_cleverhans.py で Dropout がコメントアウトされているのはそのためです.このコメントアウトを外すと,以下のように(placeholderに何も入っていないときに出る典型的な)エラーになってしまいます.

python mnist_fgsm_cleverhans.py

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'dropout_1/keras_learning_phase' with dtype bool
     [[Node: dropout_1/keras_learning_phase = Placeholder[dtype=DT_BOOL, shape=<unknown>, _device="/job:localhost/replica:0/task:0/cpu:0"]()]]

cleverhans githubページの議論 を確認する限りでは,TensorflowとKerasのdropout実装が異なることに問題の所在があるようです.これは何となく解決できそうな雰囲気がしますので,解決次第記事をアップデートします.

また,KerasModelWrapperを挟むことによって,モデル学習の過程がほとんど見えなくなってしまうようです.cleverhansのソースを見る限りでは,ログが出ないわけがないのですが,Kerasのように,batchごとの進捗がまったく見えなくなり,1 epochが終わったところでシンプルなログが1行出るだけになっています.これも解決次第記事をアップデートしようかと思います.

以上となります.素敵な敵対的サンプル生成ライフをお楽しみ下さい.


  1. この文脈における敵対的学習は,よく知られた generative adversarial networks (GAN) における generator と discriminator との競争的な学習戦略とは異なります.しかし,NIPS2016 Workshop on Adversarial Training の記載にもあるように,敵対的な何か(モデル構造の場合もあれば,サンプルの場合もある)を利用した学習戦略という意味では,両者は同一視することができます. 

  2. 勾配の各要素にどのようなノルムを適用するかによって,いくつかの亜種を構成することができます.Goodfellow+ ICLR2015 の方法は,勾配の各要素に$\ell_\infty$-normを適用したものと見なすことができます.FastGradientMethod.generateを呼び出す際に,ord=1もしくは2を指定することで$\ell_1$-normもしくは$\ell_2$-normを適用した亜種を利用できます. 

33
28
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
33
28