STL-10のunlabeledの画像10万枚(96x96)をImageDataGeneratorで回してたらメモリ12GB近く使ってパンクしそうになったので対処法を考えました。
環境:Keras=v2.2.4、TensorFlow=v1.8.0、CPU環境
結論だけ見たい方は「解決法」のとこまで飛んでください。
ImageDataGenerator.flow()は入力データを全部float32にキャストしてる?
STL-10はtrain(5000枚), test(8000枚), unlabeled(10万枚)の3種類のデータからなり、それぞれ1つずつの大きなバイナリファイルに固められています。Pythonの実装は詳しくはこちらにあります。
STL-10の実装は本質的なことではないのですが、このバイナリの画像データがuint8のNumpy配列で定義されているのがポイントなのです。uint8は1バイトで0~255の値で示されるので、ピクセルを表すには過不足がなく無駄のないデータ型なのです。
ところが、TensorFlow(Keras)の計算で使われるのはfloat32なので、文字通り32ビット=4バイト使います。メモリにおいておくだけなら、float32だと4倍の無駄があるのです。例えばSTL-10のunlabeledのように、96×96解像度のカラー画像が10万枚あると、uint8では2.6GBのメモリで済みますが、float32では10.3GBも必要になります。つまり、データをメモリに置く場合は、uint8で置いておいてバッチで切り出すときにfloat32に置き換えるのがメモリ効率の良いやり方になります。
ImageDataGeneratorの場合
では、uint8でキャストして代入した画像をImageDataGeneratorに食わせると、勝手にバッチ単位でfloat32に適宜置き換えてくれるかというとそうではないようです。memory_profilerを使って詳しく見ていきます。事前にpip install memory_profiler
などでインストールしておきます。
from keras.layers import Input, Dense, Flatten
from keras.models import Model
from keras.datasets import mnist
from keras.preprocessing.image import ImageDataGenerator
from keras.utils import to_categorical
from keras.optimizers import Adam
from memory_profiler import profile
import numpy as np
@profile
def test_func_imagegen():
input = Input(shape=(28, 28, 1))
x = Flatten()(input)
x = Dense(64, activation="relu")(x)
x = Dense(10, activation="softmax")(x)
model = Model(input, x)
model.compile(Adam(), loss="categorical_crossentropy")
gen = ImageDataGenerator(rescale=1.0/255)
(X_train, y_train), (_, _) = mnist.load_data()
X_train = X_train.reshape(-1, 28, 28, 1).astype(np.uint8)
y_train = to_categorical(y_train).astype(np.uint8)
model.fit_generator(gen.flow(X_train, y_train, batch_size=128), steps_per_epoch=50000/128, epochs=10)
if __name__ == "__main__":
test_func_imagegen()
簡単なMNISTの例ですが、memory_profilerは以下のようになります。memory_profilerはデコレーターとして簡単に使えます。
Line # Mem usage Increment Line Contents
================================================
49 170.4 MiB 170.4 MiB @profile
50 def test_func_imagegen():
51 173.3 MiB 2.8 MiB input = Input(shape=(28, 28, 1))
52 173.8 MiB 0.5 MiB x = Flatten()(input)
53 174.0 MiB 0.2 MiB x = Dense(64, activation="relu")(x)
54 174.1 MiB 0.1 MiB x = Dense(10, activation="softmax")(x)
55 174.1 MiB 0.0 MiB model = Model(input, x)
56 174.3 MiB 0.2 MiB model.compile(Adam(), loss="categorical_crossentropy")
57
58 174.3 MiB 0.0 MiB gen = ImageDataGenerator(rescale=1.0/255)
59
60 221.1 MiB 46.9 MiB (X_train, y_train), (_, _) = mnist.load_data()
61 221.1 MiB 0.0 MiB X_train = X_train.reshape(-1, 28, 28, 1).astype(np.uint8)
62 221.2 MiB 0.0 MiB y_train = to_categorical(y_train).astype(np.uint8)
63
64 419.2 MiB 198.0 MiB model.fit_generator(gen.flow(X_train, y_train, batch_size=128), steps_per_epoch=50000/128, epochs=10)
MNISTをロードしたところで、46.9MiBメモリ(使用量)が増えているのが確認できます。次のreshape, to_categorialをuint8でキャストしても増加はないので、もともとの型が8ビットの型であったのがわかります。つまり、8ビット変数でのMNISTデータは46.9MiBぐらいと言えます。
ところが、次のImageDataGenerator.flowのところでは一気にメモリが198.0MiB増加しています。このflowの返り値はfloat32で返ってきますが、これはImageDataGenerator.flowの内部的にはバッチ単位でfloat32に変換しているのではなく、X_train全体をfloat32でキャストしているのです。先程の46.9MiB×4=187.6MiBでだいたい説明できますよね。
訓練データをfloat32にキャストしてからmodel.fitの場合
これだけだと「本当かよ?モデルがでかいだけじゃねえの?」と信用してもらえなさそうなので、比較例としてMNIST全体をfloat32でキャストしてから(+255で割ってから)、model.fitで訓練させてみます。コードはこちらになります。
@profile
def test_func_fit():
input = Input(shape=(28, 28, 1))
x = Flatten()(input)
x = Dense(64, activation="relu")(x)
x = Dense(10, activation="softmax")(x)
model = Model(input, x)
model.compile(Adam(), loss="categorical_crossentropy")
(X_train, y_train), (_, _) = mnist.load_data()
X_train = (X_train.reshape(-1, 28, 28, 1) / 255.0).astype(np.float32)
y_train = to_categorical(y_train).astype(np.float32)
model.fit(X_train, y_train, batch_size=128, epochs=10)
こちらは対照的にfitさせる前にfloat32にキャストさせました。
Line # Mem usage Increment Line Contents
================================================
33 170.3 MiB 170.3 MiB @profile
34 def test_func_fit():
35 173.1 MiB 2.8 MiB input = Input(shape=(28, 28, 1))
36 173.6 MiB 0.5 MiB x = Flatten()(input)
37 173.8 MiB 0.2 MiB x = Dense(64, activation="relu")(x)
38 173.9 MiB 0.1 MiB x = Dense(10, activation="softmax")(x)
39 173.9 MiB 0.0 MiB model = Model(input, x)
40 174.1 MiB 0.2 MiB model.compile(Adam(), loss="categorical_crossentropy")
41
42 220.9 MiB 46.9 MiB (X_train, y_train), (_, _) = mnist.load_data()
43 355.8 MiB 134.9 MiB X_train = (X_train.reshape(-1, 28, 28, 1) / 255.0).astype(np.float32)
44 358.1 MiB 2.3 MiB y_train = to_categorical(y_train).astype(np.float32)
45
46 374.1 MiB 16.0 MiB model.fit(X_train, y_train, batch_size=128, epochs=10)
この場合は、X_trainを255で割ってfloat32でキャストしている部分で大きく(134.9MiB)メモリを割り当てているのがわかります。逆にmodel.fitの部分では16.0MiBとそこまで使ってはいないですね。
float32にキャストしてからImageDataGeneratorを使う場合
先程はuint8でキャストしてからImageDataGeneratorに食わせましたが、より慎重に調査するために、float32でキャストしてからImageDataGeneratorに食わせます。結果は次のようになりました。
Line # Mem usage Increment Line Contents
================================================
66 170.2 MiB 170.2 MiB @profile
67 def test_func_imagegen_float32():
68 172.9 MiB 2.7 MiB input = Input(shape=(28, 28, 1))
69 173.3 MiB 0.5 MiB x = Flatten()(input)
70 173.6 MiB 0.2 MiB x = Dense(64, activation="relu")(x)
71 173.6 MiB 0.1 MiB x = Dense(10, activation="softmax")(x)
72 173.6 MiB 0.0 MiB model = Model(input, x)
73 173.8 MiB 0.2 MiB model.compile(Adam(), loss="categorical_crossentropy")
74
75 173.8 MiB 0.0 MiB gen = ImageDataGenerator(rescale=1.0/255)
76
77 220.7 MiB 46.9 MiB (X_train, y_train), (_, _) = mnist.load_data()
78 355.5 MiB 134.8 MiB X_train = X_train.reshape(-1, 28, 28, 1).astype(np.float32)
79 357.9 MiB 2.3 MiB y_train = to_categorical(y_train).astype(np.float32)
80
81 378.4 MiB 20.5 MiB model.fit_generator(gen.flow(X_train, y_train, batch_size=128), steps_per_epoch=50000/128, epochs=10)
この場合は、メモリの大きな割当はfloat32にキャストの部分で発生していて、fit_generatorの部分ではほとんど発生しなくなりました。
これで、ImageDataGeneratorにuint8の画像をflowで読ませると、バッチ単位ではなく画像データ全体をfloat32にキャストしているのが理解できたでしょうか?
ここまでのまとめ
ここまでの結論を整理すると以下のとおりです。
- 画像データをメモリ内に確保しておきたいのなら、uint8で置いておいて、バッチにする際にfloat32にキャストするのがよい
- しかし、ImageDataGenerator.flowではuint8で入力された変数をバッチで切り出すタイミングでfloat32に変換せず、全体を最初にfloat32に変換し、その一部をスライスして出力している
- したがって、画像サイズが大きくなるとImageDataGeneratorではかなりメモリの無駄が発生する。余計なスワップやアロケーションが目立ち効率が悪い。
次に解決策を書きます。
解決策:オリジナルのジェネレーターを作ろう
ImageDataGeneratorのどこが問題なのか明白になったので、自分でジェネレーターを作ってしまえばいいのです。
class CustomGenerator:
def flow(self, X, y=None, batch_size=32, shuffle=True):
if not y is None:
assert X.shape[0] == y.shape[0]
n_sample = X.shape[0]
assert batch_size <= n_sample
n_batch = n_sample // batch_size
while True:
indices = np.arange(n_sample)
if shuffle:
np.random.shuffle(indices)
for i in range(n_batch):
current_indices = indices[i*batch_size:(i+1)*batch_size]
X_batch = (X[current_indices] / 255.0).astype(np.float32)
if y is None:
yield X_batch
else:
y_batch = (y[current_indices]).astype(np.float32)
yield X_batch, y_batch
解説していきます。サンプルのシャッフルはインデックスの配列(indices)を作り、インデックスをシャッフルしています。シャッフルしたインデックスをもとに元のデータをスライスすることで、バッチを切り出します。そして最後に(ここポイント)0.0~1.0のスケール変換後、バッチ単位でfloat32にキャストし、返り値として渡してやります。
もちろんこのジェネレーターはmodel.fit_generatorに渡すことができます。メモリープロファイラーを見てみましょう。
@profile
def test_func_customgen():
input = Input(shape=(28, 28, 1))
x = Flatten()(input)
x = Dense(64, activation="relu")(x)
x = Dense(10, activation="softmax")(x)
model = Model(input, x)
model.compile(Adam(), loss="categorical_crossentropy")
gen = CustomGenerator()
(X_train, y_train), (_, _) = mnist.load_data()
X_train = X_train.reshape(-1, 28, 28, 1).astype(np.uint8)
y_train = to_categorical(y_train).astype(np.uint8)
model.fit_generator(gen.flow(X_train, y_train, batch_size=128), steps_per_epoch=50000/128, epochs=10)
最初のImageDataGeneratorの例のImageDataGenerator()をCustomGenerator()に置き換えただけです。
Line # Mem usage Increment Line Contents
================================================
84 170.3 MiB 170.3 MiB @profile
85 def test_func_customgen():
86 173.0 MiB 2.7 MiB input = Input(shape=(28, 28, 1))
87 173.5 MiB 0.5 MiB x = Flatten()(input)
88 173.7 MiB 0.2 MiB x = Dense(64, activation="relu")(x)
89 173.8 MiB 0.1 MiB x = Dense(10, activation="softmax")(x)
90 173.8 MiB 0.0 MiB model = Model(input, x)
91 174.0 MiB 0.2 MiB model.compile(Adam(), loss="categorical_crossentropy")
92
93 174.0 MiB 0.0 MiB gen = CustomGenerator()
94
95 220.7 MiB 46.7 MiB (X_train, y_train), (_, _) = mnist.load_data()
96 220.7 MiB 0.0 MiB X_train = X_train.reshape(-1, 28, 28, 1).astype(np.uint8)
97 220.7 MiB 0.0 MiB y_train = to_categorical(y_train).astype(np.uint8)
98
99 241.2 MiB 20.5 MiB model.fit_generator(gen.flow(X_train, y_train, batch_size=128), steps_per_epoch=50000/128, epochs=10)
やったぜ。
バッチ単位でfloat32にキャストすることで、メモリの使用量を大きく減らすことができました。欠点として、キャストの回数が増えるので若干遅くなるかもしれませんが、この例では目に見えて遅くなるということはありませんでした。
以上です。「メモリ食い過ぎやゴルァ!!」って思ったら試してみてください。