LoginSignup
38
34

More than 5 years have passed since last update.

KerasのImageDataGeneratorのメモリ消費を約1/4にする方法

Last updated at Posted at 2018-10-13

STL-10のunlabeledの画像10万枚(96x96)をImageDataGeneratorで回してたらメモリ12GB近く使ってパンクしそうになったので対処法を考えました。

環境:Keras=v2.2.4、TensorFlow=v1.8.0、CPU環境

結論だけ見たい方は「解決法」のとこまで飛んでください。

ImageDataGenerator.flow()は入力データを全部float32にキャストしてる?

STL-10はtrain(5000枚), test(8000枚), unlabeled(10万枚)の3種類のデータからなり、それぞれ1つずつの大きなバイナリファイルに固められています。Pythonの実装は詳しくはこちらにあります。

STL-10の実装は本質的なことではないのですが、このバイナリの画像データがuint8のNumpy配列で定義されているのがポイントなのです。uint8は1バイトで0~255の値で示されるので、ピクセルを表すには過不足がなく無駄のないデータ型なのです。

ところが、TensorFlow(Keras)の計算で使われるのはfloat32なので、文字通り32ビット=4バイト使います。メモリにおいておくだけなら、float32だと4倍の無駄があるのです。例えばSTL-10のunlabeledのように、96×96解像度のカラー画像が10万枚あると、uint8では2.6GBのメモリで済みますが、float32では10.3GBも必要になります。つまり、データをメモリに置く場合は、uint8で置いておいてバッチで切り出すときにfloat32に置き換えるのがメモリ効率の良いやり方になります。

ImageDataGeneratorの場合

では、uint8でキャストして代入した画像をImageDataGeneratorに食わせると、勝手にバッチ単位でfloat32に適宜置き換えてくれるかというとそうではないようです。memory_profilerを使って詳しく見ていきます。事前にpip install memory_profilerなどでインストールしておきます。

from keras.layers import Input, Dense, Flatten
from keras.models import Model
from keras.datasets import mnist
from keras.preprocessing.image import ImageDataGenerator
from keras.utils import to_categorical
from keras.optimizers import Adam
from memory_profiler import profile
import numpy as np

@profile
def test_func_imagegen():
    input = Input(shape=(28, 28, 1))
    x = Flatten()(input)
    x = Dense(64, activation="relu")(x)
    x = Dense(10, activation="softmax")(x)
    model = Model(input, x)
    model.compile(Adam(), loss="categorical_crossentropy")

    gen = ImageDataGenerator(rescale=1.0/255)

    (X_train, y_train), (_, _) = mnist.load_data()
    X_train = X_train.reshape(-1, 28, 28, 1).astype(np.uint8)
    y_train = to_categorical(y_train).astype(np.uint8)

    model.fit_generator(gen.flow(X_train, y_train, batch_size=128), steps_per_epoch=50000/128, epochs=10)

if __name__ == "__main__":
    test_func_imagegen()

簡単なMNISTの例ですが、memory_profilerは以下のようになります。memory_profilerはデコレーターとして簡単に使えます。

ImageDataGenerator
Line #    Mem usage    Increment   Line Contents
================================================
    49    170.4 MiB    170.4 MiB   @profile
    50                             def test_func_imagegen():
    51    173.3 MiB      2.8 MiB       input = Input(shape=(28, 28, 1))
    52    173.8 MiB      0.5 MiB       x = Flatten()(input)
    53    174.0 MiB      0.2 MiB       x = Dense(64, activation="relu")(x)
    54    174.1 MiB      0.1 MiB       x = Dense(10, activation="softmax")(x)
    55    174.1 MiB      0.0 MiB       model = Model(input, x)
    56    174.3 MiB      0.2 MiB       model.compile(Adam(), loss="categorical_crossentropy")
    57
    58    174.3 MiB      0.0 MiB       gen = ImageDataGenerator(rescale=1.0/255)
    59
    60    221.1 MiB     46.9 MiB       (X_train, y_train), (_, _) = mnist.load_data()
    61    221.1 MiB      0.0 MiB       X_train = X_train.reshape(-1, 28, 28, 1).astype(np.uint8)
    62    221.2 MiB      0.0 MiB       y_train = to_categorical(y_train).astype(np.uint8)
    63
    64    419.2 MiB    198.0 MiB       model.fit_generator(gen.flow(X_train, y_train, batch_size=128), steps_per_epoch=50000/128, epochs=10)

MNISTをロードしたところで、46.9MiBメモリ(使用量)が増えているのが確認できます。次のreshape, to_categorialをuint8でキャストしても増加はないので、もともとの型が8ビットの型であったのがわかります。つまり、8ビット変数でのMNISTデータは46.9MiBぐらいと言えます。

ところが、次のImageDataGenerator.flowのところでは一気にメモリが198.0MiB増加しています。このflowの返り値はfloat32で返ってきますが、これはImageDataGenerator.flowの内部的にはバッチ単位でfloat32に変換しているのではなく、X_train全体をfloat32でキャストしているのです。先程の46.9MiB×4=187.6MiBでだいたい説明できますよね。

訓練データをfloat32にキャストしてからmodel.fitの場合

これだけだと「本当かよ?モデルがでかいだけじゃねえの?」と信用してもらえなさそうなので、比較例としてMNIST全体をfloat32でキャストしてから(+255で割ってから)、model.fitで訓練させてみます。コードはこちらになります。

@profile
def test_func_fit():
    input = Input(shape=(28, 28, 1))
    x = Flatten()(input)
    x = Dense(64, activation="relu")(x)
    x = Dense(10, activation="softmax")(x)
    model = Model(input, x)
    model.compile(Adam(), loss="categorical_crossentropy")

    (X_train, y_train), (_, _) = mnist.load_data()
    X_train = (X_train.reshape(-1, 28, 28, 1) / 255.0).astype(np.float32)
    y_train = to_categorical(y_train).astype(np.float32)

    model.fit(X_train, y_train, batch_size=128, epochs=10)

こちらは対照的にfitさせる前にfloat32にキャストさせました。

fit
Line #    Mem usage    Increment   Line Contents
================================================
    33    170.3 MiB    170.3 MiB   @profile
    34                             def test_func_fit():
    35    173.1 MiB      2.8 MiB       input = Input(shape=(28, 28, 1))
    36    173.6 MiB      0.5 MiB       x = Flatten()(input)
    37    173.8 MiB      0.2 MiB       x = Dense(64, activation="relu")(x)
    38    173.9 MiB      0.1 MiB       x = Dense(10, activation="softmax")(x)
    39    173.9 MiB      0.0 MiB       model = Model(input, x)
    40    174.1 MiB      0.2 MiB       model.compile(Adam(), loss="categorical_crossentropy")
    41
    42    220.9 MiB     46.9 MiB       (X_train, y_train), (_, _) = mnist.load_data()
    43    355.8 MiB    134.9 MiB       X_train = (X_train.reshape(-1, 28, 28, 1) / 255.0).astype(np.float32)
    44    358.1 MiB      2.3 MiB       y_train = to_categorical(y_train).astype(np.float32)
    45
    46    374.1 MiB     16.0 MiB       model.fit(X_train, y_train, batch_size=128, epochs=10)

この場合は、X_trainを255で割ってfloat32でキャストしている部分で大きく(134.9MiB)メモリを割り当てているのがわかります。逆にmodel.fitの部分では16.0MiBとそこまで使ってはいないですね。

float32にキャストしてからImageDataGeneratorを使う場合

先程はuint8でキャストしてからImageDataGeneratorに食わせましたが、より慎重に調査するために、float32でキャストしてからImageDataGeneratorに食わせます。結果は次のようになりました。

float32→ImageDataGenerator
Line #    Mem usage    Increment   Line Contents
================================================
    66    170.2 MiB    170.2 MiB   @profile
    67                             def test_func_imagegen_float32():
    68    172.9 MiB      2.7 MiB       input = Input(shape=(28, 28, 1))
    69    173.3 MiB      0.5 MiB       x = Flatten()(input)
    70    173.6 MiB      0.2 MiB       x = Dense(64, activation="relu")(x)
    71    173.6 MiB      0.1 MiB       x = Dense(10, activation="softmax")(x)
    72    173.6 MiB      0.0 MiB       model = Model(input, x)
    73    173.8 MiB      0.2 MiB       model.compile(Adam(), loss="categorical_crossentropy")
    74
    75    173.8 MiB      0.0 MiB       gen = ImageDataGenerator(rescale=1.0/255)

    76
    77    220.7 MiB     46.9 MiB       (X_train, y_train), (_, _) = mnist.load_data()
    78    355.5 MiB    134.8 MiB       X_train = X_train.reshape(-1, 28, 28, 1).astype(np.float32)
    79    357.9 MiB      2.3 MiB       y_train = to_categorical(y_train).astype(np.float32)
    80
    81    378.4 MiB     20.5 MiB       model.fit_generator(gen.flow(X_train, y_train, batch_size=128), steps_per_epoch=50000/128, epochs=10)

この場合は、メモリの大きな割当はfloat32にキャストの部分で発生していて、fit_generatorの部分ではほとんど発生しなくなりました

これで、ImageDataGeneratorにuint8の画像をflowで読ませると、バッチ単位ではなく画像データ全体をfloat32にキャストしているのが理解できたでしょうか?

ここまでのまとめ

ここまでの結論を整理すると以下のとおりです。

  • 画像データをメモリ内に確保しておきたいのなら、uint8で置いておいて、バッチにする際にfloat32にキャストするのがよい
  • しかし、ImageDataGenerator.flowではuint8で入力された変数をバッチで切り出すタイミングでfloat32に変換せず、全体を最初にfloat32に変換し、その一部をスライスして出力している
  • したがって、画像サイズが大きくなるとImageDataGeneratorではかなりメモリの無駄が発生する。余計なスワップやアロケーションが目立ち効率が悪い。

次に解決策を書きます。

解決策:オリジナルのジェネレーターを作ろう

ImageDataGeneratorのどこが問題なのか明白になったので、自分でジェネレーターを作ってしまえばいいのです。

カスタムジェネレーター
class CustomGenerator:
    def flow(self, X, y=None, batch_size=32, shuffle=True):
        if not y is None:
            assert X.shape[0] == y.shape[0]
        n_sample = X.shape[0]
        assert batch_size <= n_sample
        n_batch = n_sample // batch_size

        while True:
            indices = np.arange(n_sample)
            if shuffle:
                np.random.shuffle(indices)

            for i in range(n_batch):
                current_indices = indices[i*batch_size:(i+1)*batch_size]
                X_batch = (X[current_indices] / 255.0).astype(np.float32)                
                if y is None:
                    yield X_batch
                else:
                    y_batch = (y[current_indices]).astype(np.float32)
                    yield X_batch, y_batch

解説していきます。サンプルのシャッフルはインデックスの配列(indices)を作り、インデックスをシャッフルしています。シャッフルしたインデックスをもとに元のデータをスライスすることで、バッチを切り出します。そして最後に(ここポイント)0.0~1.0のスケール変換後、バッチ単位でfloat32にキャストし、返り値として渡してやります。

もちろんこのジェネレーターはmodel.fit_generatorに渡すことができます。メモリープロファイラーを見てみましょう。

@profile
def test_func_customgen():
    input = Input(shape=(28, 28, 1))
    x = Flatten()(input)
    x = Dense(64, activation="relu")(x)
    x = Dense(10, activation="softmax")(x)
    model = Model(input, x)
    model.compile(Adam(), loss="categorical_crossentropy")

    gen = CustomGenerator()

    (X_train, y_train), (_, _) = mnist.load_data()
    X_train = X_train.reshape(-1, 28, 28, 1).astype(np.uint8)
    y_train = to_categorical(y_train).astype(np.uint8)

    model.fit_generator(gen.flow(X_train, y_train, batch_size=128), steps_per_epoch=50000/128, epochs=10)

最初のImageDataGeneratorの例のImageDataGenerator()をCustomGenerator()に置き換えただけです。

CustomGenerator
Line #    Mem usage    Increment   Line Contents
================================================
    84    170.3 MiB    170.3 MiB   @profile
    85                             def test_func_customgen():
    86    173.0 MiB      2.7 MiB       input = Input(shape=(28, 28, 1))
    87    173.5 MiB      0.5 MiB       x = Flatten()(input)
    88    173.7 MiB      0.2 MiB       x = Dense(64, activation="relu")(x)
    89    173.8 MiB      0.1 MiB       x = Dense(10, activation="softmax")(x)
    90    173.8 MiB      0.0 MiB       model = Model(input, x)
    91    174.0 MiB      0.2 MiB       model.compile(Adam(), loss="categorical_crossentropy")
    92
    93    174.0 MiB      0.0 MiB       gen = CustomGenerator()
    94
    95    220.7 MiB     46.7 MiB       (X_train, y_train), (_, _) = mnist.load_data()
    96    220.7 MiB      0.0 MiB       X_train = X_train.reshape(-1, 28, 28, 1).astype(np.uint8)
    97    220.7 MiB      0.0 MiB       y_train = to_categorical(y_train).astype(np.uint8)
    98
    99    241.2 MiB     20.5 MiB       model.fit_generator(gen.flow(X_train, y_train, batch_size=128), steps_per_epoch=50000/128, epochs=10)

やったぜ。

バッチ単位でfloat32にキャストすることで、メモリの使用量を大きく減らすことができました。欠点として、キャストの回数が増えるので若干遅くなるかもしれませんが、この例では目に見えて遅くなるということはありませんでした。

以上です。「メモリ食い過ぎやゴルァ!!」って思ったら試してみてください。

38
34
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
38
34