27
Help us understand the problem. What are the problem?

More than 1 year has passed since last update.

posted at

kerasで繰り返し学習するとメモリ使用量が増えちゃう問題を対策した

kerasに一晩学習させてるとこんな感じになっちゃうのを対策しました

症状

kerasで繰り返し学習していたところメインメモリの使用量がじわじわ増えてしまうという問題が生じました。メモリリークしちゃったかな?と思ってtracemallockで調べたところ、tensorflow_core/pytho/util/tf_stack.py内のextract_stack()の返り値として作られているlistが原因だと分かりましたが、そんな子は呼んだ覚えがありません。ログか何かを溜め込んでいるんでしょうか?誰か教えて下さい…
EDysDv0UEAAXcXo.png

発症環境

Ubuntu18.04
(cpu:確認して追記します) / TitanV x 1
tensorflow-gpu 1.14.1 / keras 2.2.4

Windows10
intel core i9-9900k / RTX2080Ti x 1
tensorflow-gpu 2.0.0rc0

Windows10
AMD Ryzen7-1700X / GTX1080Ti x 1
tensorflow-gpu 2.0.0rc0

対策

tensorflowさんに要らないlistだけ捨てるようにお願いする方法が見つからなかったので、Sessionごと消しちゃうことにしました。keras.backend.clear_session()を呼べば、いまあるSessionまるごと消え去れます。
https://www.tensorflow.org/api_docs/python/tf/keras/backend/clear_session
https://keras.io/ja/backend/

以下のコードのようにmodelが要らなくなった時点でkeras.backend.clear_session()するとメモリ消費量じわじわ増加病は直ります。

python
import gc
import tensorflow as tf
from tensorflow import keras
physical_devices = tf.config.experimental.list_physical_devices('GPU')
if len(physical_devices) > 0:
    tf.config.experimental.set_memory_growth(physical_devices[0], True)
    print('memory growth:', tf.config.experimental.get_memory_growth(physical_devices[0]))
else:
    print("Not enough GPU hardware devices available")

def study_mnist():
    mnist = keras.datasets.mnist
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0

    inputs = keras.layers.Input(shape=(28, 28))
    x = keras.layers.Flatten()(inputs)
    x = keras.layers.Dense(128, activation='relu')(x)
    x = keras.layers.Dense(128, activation='relu')(x)
    x = keras.layers.Dropout(0.2)(x)
    x = keras.layers.Dense(10, activation=None)(x)
    predictions = keras.layers.Activation('softmax')(x)

    model = keras.models.Model(inputs=inputs, outputs=predictions)

    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    modelCheckpoint = keras.callbacks.ModelCheckpoint(filepath='../tmp.h5',
                                                      monitor='val_loss',
                                                      verbose=1,
                                                      save_best_only=True, )
    hist = model.fit(x_train, y_train, validation_split=0.1, epochs=10, verbose=0,
                     callbacks=[modelCheckpoint])

    model = keras.models.load_model('../tmp.h5')
    print('ModelCheckPointのスコア')
    print(model.evaluate(x_test, y_test, verbose=0))
    del model

    keras.backend.clear_session() # ←これです
    gc.collect()

def main():
    for k in range(100):
        study_mnist()

if __name__ == "__main__":
    main()

keras.backend.clear_session()がないとき

keras.backend.clear_session()があるとき

やったね!

Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Sign upLogin
27
Help us understand the problem. What are the problem?