Qiita Teams that are logged in
You are not logged in to any team

Log in to Qiita Team
Community
OrganizationAdvent CalendarQiitadon (β)
Service
Qiita JobsQiita ZineQiita Blog
106
Help us understand the problem. What is going on with this article?
@yu4u

新たなdata augmentation手法mixupを試してみた

More than 3 years have passed since last update.

はじめに

最近arXivに論文が公開されたdata augmentation手法であるmixupが非常にシンプルな手法だったので試してみました。

mixup

mixup1は、2つの訓練サンプルのペアを混合して新たな訓練サンプルを作成するdata augmentation手法の1つです。
具体的には、データとラベルのペア$(X_1, y_1)$, $(X_2, y_2)$から、下記の式により新たな訓練サンプル$(X, y)$を作成します。ここでラベル$y_1, y_2$はone-hot表現のベクトルになっているものとします。$X_1, X_2$は任意のベクトルやテンソルです。

X = \lambda X_1 + (1 - \lambda) X_2 \\
y = \lambda y_1 + (1 - \lambda) y_2

ここで$\lambda \in [0, 1]$は、ベータ分布$Be(\alpha, \alpha)$からのサンプリングにより取得し、$\alpha$はハイパーパラメータとなります。特徴的なのは、データ$X_1, X_2$だけではなく、ラベル$y_1, y_2$も混合してしまう点です。

この定式化の解釈の参考記事:
http://www.inference.vc/mixup-data-dependent-data-augmentation/

実装

ジェネレータとして実装します。
https://github.com/yu4u/mixup-generator

import numpy as np


class MixupGenerator():
    def __init__(self, X_train, y_train, batch_size=32, alpha=0.2, shuffle=True, datagen=None):
        self.X_train = X_train
        self.y_train = y_train
        self.batch_size = batch_size
        self.alpha = alpha
        self.shuffle = shuffle
        self.sample_num = len(X_train)
        self.datagen = datagen

    def __call__(self):
        while True:
            indexes = self.__get_exploration_order()
            itr_num = int(len(indexes) // (self.batch_size * 2))

            for i in range(itr_num):
                batch_ids = indexes[i * self.batch_size * 2:(i + 1) * self.batch_size * 2]
                X, y = self.__data_generation(batch_ids)

                yield X, y

    def __get_exploration_order(self):
        indexes = np.arange(self.sample_num)

        if self.shuffle:
            np.random.shuffle(indexes)

        return indexes

    def __data_generation(self, batch_ids):
        _, h, w, c = self.X_train.shape
        _, class_num = self.y_train.shape
        X1 = self.X_train[batch_ids[:self.batch_size]]
        X2 = self.X_train[batch_ids[self.batch_size:]]
        y1 = self.y_train[batch_ids[:self.batch_size]]
        y2 = self.y_train[batch_ids[self.batch_size:]]
        l = np.random.beta(self.alpha, self.alpha, self.batch_size)
        X_l = l.reshape(self.batch_size, 1, 1, 1)
        y_l = l.reshape(self.batch_size, 1)

        X = X1 * X_l + X2 * (1 - X_l)
        y = y1 * y_l + y2 * (1 - y_l)

        if self.datagen:
            for i in range(self.batch_size):
                X[i] = self.datagen.random_transform(X[i])

        return X, y

training_generator = MixoutGenerator(x_train, y_train)()で、訓練データとラベルの集合を引数としてジェネレータを取得し、x, y = next(training_generator)で学習用のバッチが取得できます。

CIFAR-10データセットを利用してmixupした例です。

ぼんやり2枚の画像がアルファ合成されたような画像が出力されます。

ジェネレータを利用した訓練

例えばKerasであれば、ジェネレータをそのまま学習する関数に渡してあげれば学習することができます。

model.fit_generator(generator=training_generator,
                    steps_per_epoch=x_train.shape[0] // batch_size,
                    validation_data=(x_test, y_test),
                    epochs=epochs, verbose=1,
                    callbacks=callbacks)

Kerasには、画像をランダムにスケーリングしたり、シフトしたりしてくれる便利なImageDataGeneratorがあり、これと組み合わせることもできます。

datagen = ImageDataGenerator(
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True)

training_generator = MixupGenerator(x_train, y_train, datagen=datagen)()

このケースでは、まずmixupされたデータが作成され、その後ImageDataGeneratorによりランダムな変換が加えられます。

実験結果

Kerasのサンプルのcifar10_resnet.pyを少しいじって実験してみました。

mixupを使わないケース:

Test loss: 0.862150103855
Test accuracy: 0.8978

mixupを使うケース($\alpha = 0.2$):

Test loss: 0.510702615929
Test accuracy: 0.9117

1回のみの試行ですが、効果はありそうです。個人的には、こちらの記事で紹介しているRandom Erasingのほうが画像ドメインでは効果がありそうな印象ですが、組み合わせてみるのも面白いかもしれません。学習させられてるネットワークからすると、ランダムに2枚の画像が合成されたり、ランダムに画像の一部が欠落させられたりして、たまったものではないですが…


  1. H. Zhang, M. Cisse, Y. N. Dauphin, and D. Lopez-Paz, "mixup: Beyond Empirical Risk Minimization," in arXiv:1710.09412, 2017. 

106
Help us understand the problem. What is going on with this article?
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away

Comments

No comments
Sign up for free and join this conversation.
Sign Up
If you already have a Qiita account Login
106
Help us understand the problem. What is going on with this article?