LoginSignup
33
30

More than 3 years have passed since last update.

kerasで繰り返し学習するとメモリ使用量が増えちゃう問題を対策した

Posted at

kerasに一晩学習させてるとこんな感じになっちゃうのを対策しました

症状

kerasで繰り返し学習していたところメインメモリの使用量がじわじわ増えてしまうという問題が生じました。メモリリークしちゃったかな?と思ってtracemallockで調べたところ、tensorflow_core/pytho/util/tf_stack.py内のextract_stack()の返り値として作られているlistが原因だと分かりましたが、そんな子は呼んだ覚えがありません。ログか何かを溜め込んでいるんでしょうか?誰か教えて下さい…
EDysDv0UEAAXcXo.png

発症環境

Ubuntu18.04
(cpu:確認して追記します) / TitanV x 1
tensorflow-gpu 1.14.1 / keras 2.2.4

Windows10
intel core i9-9900k / RTX2080Ti x 1
tensorflow-gpu 2.0.0rc0

Windows10
AMD Ryzen7-1700X / GTX1080Ti x 1
tensorflow-gpu 2.0.0rc0

対策

tensorflowさんに要らないlistだけ捨てるようにお願いする方法が見つからなかったので、Sessionごと消しちゃうことにしました。keras.backend.clear_session()を呼べば、いまあるSessionまるごと消え去れます。
https://www.tensorflow.org/api_docs/python/tf/keras/backend/clear_session
https://keras.io/ja/backend/

以下のコードのようにmodelが要らなくなった時点でkeras.backend.clear_session()するとメモリ消費量じわじわ増加病は直ります。

python
import gc
import tensorflow as tf
from tensorflow import keras
physical_devices = tf.config.experimental.list_physical_devices('GPU')
if len(physical_devices) > 0:
    tf.config.experimental.set_memory_growth(physical_devices[0], True)
    print('memory growth:', tf.config.experimental.get_memory_growth(physical_devices[0]))
else:
    print("Not enough GPU hardware devices available")

def study_mnist():
    mnist = keras.datasets.mnist
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0

    inputs = keras.layers.Input(shape=(28, 28))
    x = keras.layers.Flatten()(inputs)
    x = keras.layers.Dense(128, activation='relu')(x)
    x = keras.layers.Dense(128, activation='relu')(x)
    x = keras.layers.Dropout(0.2)(x)
    x = keras.layers.Dense(10, activation=None)(x)
    predictions = keras.layers.Activation('softmax')(x)

    model = keras.models.Model(inputs=inputs, outputs=predictions)

    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    modelCheckpoint = keras.callbacks.ModelCheckpoint(filepath='../tmp.h5',
                                                      monitor='val_loss',
                                                      verbose=1,
                                                      save_best_only=True, )
    hist = model.fit(x_train, y_train, validation_split=0.1, epochs=10, verbose=0,
                     callbacks=[modelCheckpoint])

    model = keras.models.load_model('../tmp.h5')
    print('ModelCheckPointのスコア')
    print(model.evaluate(x_test, y_test, verbose=0))
    del model

    keras.backend.clear_session() # ←これです
    gc.collect()

def main():
    for k in range(100):
        study_mnist()

if __name__ == "__main__":
    main()

keras.backend.clear_session()がないとき

keras.backend.clear_session()があるとき

やったね!

33
30
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
33
30