Help us understand the problem. What is going on with this article?

新たなdata augmentation手法mixupを試してみた

More than 1 year has passed since last update.

はじめに

最近arXivに論文が公開されたdata augmentation手法であるmixupが非常にシンプルな手法だったので試してみました。

mixup

mixup1は、2つの訓練サンプルのペアを混合して新たな訓練サンプルを作成するdata augmentation手法の1つです。
具体的には、データとラベルのペア$(X_1, y_1)$, $(X_2, y_2)$から、下記の式により新たな訓練サンプル$(X, y)$を作成します。ここでラベル$y_1, y_2$はone-hot表現のベクトルになっているものとします。$X_1, X_2$は任意のベクトルやテンソルです。

X = \lambda X_1 + (1 - \lambda) X_2 \\
y = \lambda y_1 + (1 - \lambda) y_2

ここで$\lambda \in [0, 1]$は、ベータ分布$Be(\alpha, \alpha)$からのサンプリングにより取得し、$\alpha$はハイパーパラメータとなります。特徴的なのは、データ$X_1, X_2$だけではなく、ラベル$y_1, y_2$も混合してしまう点です。

この定式化の解釈の参考記事:
http://www.inference.vc/mixup-data-dependent-data-augmentation/

実装

ジェネレータとして実装します。
https://github.com/yu4u/mixup-generator

import numpy as np


class MixupGenerator():
    def __init__(self, X_train, y_train, batch_size=32, alpha=0.2, shuffle=True, datagen=None):
        self.X_train = X_train
        self.y_train = y_train
        self.batch_size = batch_size
        self.alpha = alpha
        self.shuffle = shuffle
        self.sample_num = len(X_train)
        self.datagen = datagen

    def __call__(self):
        while True:
            indexes = self.__get_exploration_order()
            itr_num = int(len(indexes) // (self.batch_size * 2))

            for i in range(itr_num):
                batch_ids = indexes[i * self.batch_size * 2:(i + 1) * self.batch_size * 2]
                X, y = self.__data_generation(batch_ids)

                yield X, y

    def __get_exploration_order(self):
        indexes = np.arange(self.sample_num)

        if self.shuffle:
            np.random.shuffle(indexes)

        return indexes

    def __data_generation(self, batch_ids):
        _, h, w, c = self.X_train.shape
        _, class_num = self.y_train.shape
        X1 = self.X_train[batch_ids[:self.batch_size]]
        X2 = self.X_train[batch_ids[self.batch_size:]]
        y1 = self.y_train[batch_ids[:self.batch_size]]
        y2 = self.y_train[batch_ids[self.batch_size:]]
        l = np.random.beta(self.alpha, self.alpha, self.batch_size)
        X_l = l.reshape(self.batch_size, 1, 1, 1)
        y_l = l.reshape(self.batch_size, 1)

        X = X1 * X_l + X2 * (1 - X_l)
        y = y1 * y_l + y2 * (1 - y_l)

        if self.datagen:
            for i in range(self.batch_size):
                X[i] = self.datagen.random_transform(X[i])

        return X, y

training_generator = MixoutGenerator(x_train, y_train)()で、訓練データとラベルの集合を引数としてジェネレータを取得し、x, y = next(training_generator)で学習用のバッチが取得できます。

CIFAR-10データセットを利用してmixupした例です。

ぼんやり2枚の画像がアルファ合成されたような画像が出力されます。

ジェネレータを利用した訓練

例えばKerasであれば、ジェネレータをそのまま学習する関数に渡してあげれば学習することができます。

model.fit_generator(generator=training_generator,
                    steps_per_epoch=x_train.shape[0] // batch_size,
                    validation_data=(x_test, y_test),
                    epochs=epochs, verbose=1,
                    callbacks=callbacks)

Kerasには、画像をランダムにスケーリングしたり、シフトしたりしてくれる便利なImageDataGeneratorがあり、これと組み合わせることもできます。

datagen = ImageDataGenerator(
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True)

training_generator = MixupGenerator(x_train, y_train, datagen=datagen)()

このケースでは、まずmixupされたデータが作成され、その後ImageDataGeneratorによりランダムな変換が加えられます。

実験結果

Kerasのサンプルのcifar10_resnet.pyを少しいじって実験してみました。

mixupを使わないケース:

Test loss: 0.862150103855
Test accuracy: 0.8978

mixupを使うケース($\alpha = 0.2$):

Test loss: 0.510702615929
Test accuracy: 0.9117

1回のみの試行ですが、効果はありそうです。個人的には、こちらの記事で紹介しているRandom Erasingのほうが画像ドメインでは効果がありそうな印象ですが、組み合わせてみるのも面白いかもしれません。学習させられてるネットワークからすると、ランダムに2枚の画像が合成されたり、ランダムに画像の一部が欠落させられたりして、たまったものではないですが…


  1. H. Zhang, M. Cisse, Y. N. Dauphin, and D. Lopez-Paz, "mixup: Beyond Empirical Risk Minimization," in arXiv:1710.09412, 2017. 

Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
No comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
ユーザーは見つかりませんでした