12
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

KerasでVAT(Virtual Adversarial Training)を使ってMNISTをやってみる

Last updated at Posted at 2017-03-11

はじめに

注: [改良版]KerasでVAT(Virtual Adversarial Training)を使ってMNISTをやってみる の方がマシな実装だと思うので参考にしてください。

遅まきながら、VAT(Virtual Adversarial Training)という学習方法を知ったのですが、Kerasでの実装が見つからなかったので実装してみました。

VATは簡単にいうと、「通常の入力X→出力Y」と「なるべく結果が異なるように入力Xに微小なノイズdを入力に加えた入力(X+d)→出力Y'」から「KL-Divergence(Y, Y')」を損失関数に余分に加えて学習をする手法です。

これだけだと何言ってるかわからないと思うので、詳しくは元の論文か、この方の解説をご覧になると良いかと思います。

VATは学習における位置づけとしては「正則化」に近いという話で、DropoutやNoiseを加える代わりになる可能性があります。Dropoutとかのパラメータを調整するのも面倒なので、VATで代用できると嬉しい気がします。

Kerasだとコスト関数や正則化関数に入力Xを使うようにするのが少し厄介なのですが、そこさえなんとかなれば、ChainerやTheanoでの実装があるので移植すればOKです。

Version

  • Python: 3.5.3
  • Keras: 1.2.2
  • Theano: 0.8.2

実装

こんな感じになりました。

ポイントは、

  • VATの計算は損失関数にだけ入っているので、通常の推論には影響しない
  • trainingするときに、教師信号であるy_trueX_trainもくっつけて渡すことで、損失関数に対して入力Xを渡している
  • 損失関数でVATの計算をするときに、入力→出力の計算を再度行うので Containerを使ってその計算を再利用できるようにしている
  • K.gradients()の後にK.stop_gradient()しないと、学習していくうちにLossがnanになるので注意(これにハマった...)

というところかと思います。

keras_mnist_vat.py
# coding: utf8
"""
* VAT: https://arxiv.org/abs/1507.00677

# 参考にしたCode
Original: https://github.com/fchollet/keras/blob/master/examples/mnist_cnn.py
VAT: https://github.com/musyoku/vat/blob/master/vat.py

# Result Example
use_dropout=False, use_vat=False: score=0.211949993095, accuracy=0.9877
use_dropout=True, use_vat=False: score=0.238920686956, accuracy=0.9853
use_dropout=False, use_vat=True: score=0.180048364889, accuracy=0.9916
use_dropout=True, use_vat=True: score=0.245401585515, accuracy=0.9901
"""

import numpy as np
from keras.engine.topology import Input, Container
from keras.engine.training import Model

np.random.seed(1337)  # for reproducibility

from keras.datasets import mnist
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Convolution2D, MaxPooling2D
from keras.utils import np_utils
from keras import backend as K

SAMPLE_SIZE = 0

batch_size = 128
nb_classes = 10
nb_epoch = 12

# input image dimensions
img_rows, img_cols = 28, 28
# number of convolutional filters to use
nb_filters = 32
# size of pooling area for max pooling
pool_size = (2, 2)
# convolution kernel size
kernel_size = (3, 3)


def main(data, use_dropout, use_vat):
    # the data, shuffled and split between train and test sets
    (X_train, y_train), (X_test, y_test) = data

    if K.image_dim_ordering() == 'th':
        X_train = X_train.reshape(X_train.shape[0], 1, img_rows, img_cols)
        X_test = X_test.reshape(X_test.shape[0], 1, img_rows, img_cols)
        input_shape = (1, img_rows, img_cols)
    else:
        X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1)
        X_test = X_test.reshape(X_test.shape[0], img_rows, img_cols, 1)
        input_shape = (img_rows, img_cols, 1)

    X_train = X_train.astype('float32')
    X_test = X_test.astype('float32')
    X_train /= 255.
    X_test /= 255.

    # convert class vectors to binary class matrices
    y_train = np_utils.to_categorical(y_train, nb_classes)
    y_test = np_utils.to_categorical(y_test, nb_classes)

    if SAMPLE_SIZE:
        X_train = X_train[:SAMPLE_SIZE]
        y_train = y_train[:SAMPLE_SIZE]
        X_test = X_test[:SAMPLE_SIZE]
        y_test = y_test[:SAMPLE_SIZE]

    my_model = MyModel(input_shape, use_dropout).build()
    my_model.training(X_train, y_train, X_test, y_test, use_vat=use_vat)

    score = my_model.model.evaluate(X_test, y_test, verbose=0)
    print("use_dropout=%s, use_vat=%s: score=%s, accuracy=%s" % (use_dropout, use_vat, score[0], score[1]))


class MyModel:
    model = None
    core_layers = None

    def __init__(self, input_shape, use_dropout=True):
        self.input_shape = input_shape
        self.use_dropout = use_dropout

    def build(self):
        input_layer = Input(self.input_shape)
        output_layer = self.core_data_flow(input_layer)
        self.model = Model(input_layer, output_layer)
        return self

    def core_data_flow(self, input_layer):
        x = Convolution2D(nb_filters, kernel_size[0], kernel_size[1], border_mode='valid')(input_layer)
        x = Activation('relu')(x)
        x = Convolution2D(nb_filters, kernel_size[0], kernel_size[1])(x)
        x = Activation('relu')(x)
        x = MaxPooling2D(pool_size=pool_size)(x)
        if self.use_dropout:
            x = Dropout(0.25)(x)

        x = Flatten()(x)
        x = Dense(128, activation="relu")(x)
        if self.use_dropout:
            x = Dropout(0.5)(x)
        x = Dense(nb_classes, activation='softmax')(x)
        self.core_layers = Container(input_layer, x)
        return x

    def training(self, X_train, y_train, X_test, y_test, use_vat=False):
        orig_loss_func = loss_func = K.categorical_crossentropy
        if use_vat:
            # y_true に concat(y_true, x_train) を取る変則的なLossFunctionを使います
            loss_func = self.loss_with_vat_loss(loss_func)
            self.model.compile(loss=loss_func, optimizer='adadelta', metrics=['accuracy'])
            # train, test ともに y, X を横に連結したデータを作ります
            yX_train = np.concatenate((y_train, X_train.reshape((X_train.shape[0], -1))), axis=1)
            yX_test = np.concatenate((y_test, X_test.reshape((X_test.shape[0], -1))), axis=1)

            # 普通に学習
            self.model.fit(X_train, yX_train, batch_size=batch_size, nb_epoch=nb_epoch,
                           verbose=1, validation_data=(X_test, yX_test))
            # 変則的なLossFunctionを与えたので、普通のLossFunctionに変更して再びcompileしないと、evaluate()が失敗します
            self.model.compile(loss=orig_loss_func, optimizer='adadelta', metrics=['accuracy'])
        else:
            self.model.compile(loss=loss_func, optimizer='adadelta', metrics=['accuracy'])
            self.model.fit(X_train, y_train, batch_size=batch_size, nb_epoch=nb_epoch,
                           verbose=1, validation_data=(X_test, y_test))

    def loss_with_vat_loss(self, original_loss_func, eps=1, xi=10, ip=1):
        def with_vat_loss(yX_train, y_pred):
            nb_output_classes = y_pred.shape[1]
            y_true = yX_train[:, :nb_output_classes]

            # VAT
            X_train = yX_train[:, nb_output_classes:].reshape((-1, ) + self.input_shape)
            d = K.random_normal(X_train.shape)

            for _ in range(ip):
                y = self.core_layers(X_train + self.normalize_vector(d) * xi)
                kld = K.sum(self.kld(y_pred, y))
                d = K.stop_gradient(K.gradients(kld, [d])[0])  # stop_gradient is important!!

            y_perturbation = self.core_layers(X_train + self.normalize_vector(d)*eps)
            kld = self.kld(y_pred, y_perturbation)
            return original_loss_func(y_pred, y_true) + kld
        return with_vat_loss

    @staticmethod
    def normalize_vector(x):
        z = K.sum(K.batch_flatten(K.square(x)), axis=1)
        while K.ndim(z) < K.ndim(x):
            z = K.expand_dims(z, dim=-1)
        return x / (K.sqrt(z) + K.epsilon())

    @staticmethod
    def kld(p, q):
        v = p * (K.log(p + K.epsilon()) - K.log(q + K.epsilon()))
        return K.sum(K.batch_flatten(v), axis=1, keepdims=True)


data = mnist.load_data()
main(data, use_dropout=False, use_vat=False)
main(data, use_dropout=True, use_vat=False)
main(data, use_dropout=False, use_vat=True)
main(data, use_dropout=True, use_vat=True)

実験結果

Dropoutあり・なし、VATあり・なしの4パターンで実験してみました。
1 epochの時間は GeForce GTX 1080で、環境変数を

KERAS_BACKEND=theano 
THEANO_FLAGS=device=gpu,floatX=float32,lib.cnmem=1

として実行したときのものです。

Dropout VAT Accuracy 1 epochの時間
使わない 使わない 98.77% 8秒
使う 使わない 98.53% 8秒
使わない 使う 99.16% 18秒
使う 使う 99.01% 19秒

まあ、そこそこ良い結果ですので、だいたい実装的にも問題ないのではないかと思っています。
VATは1batchで計算を2回行うので実行時間も約2倍になっていますね。

さいごに

VATは教師なし学習にも使えるので、応用範囲の広い学習方法です。
いろいろ活用してみたいです。

12
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
12
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?