Qiita Teams that are logged in
You are not logged in to any team

Log in to Qiita Team
Community
OrganizationAdvent CalendarQiitadon (β)
Service
Qiita JobsQiita ZineQiita Blog
22
Help us understand the problem. What is going on with this article?
@studio_haneya

kerasで繰り返し学習するとメモリ使用量が増えちゃう問題を対策した

More than 1 year has passed since last update.

kerasに一晩学習させてるとこんな感じになっちゃうのを対策しました

症状

kerasで繰り返し学習していたところメインメモリの使用量がじわじわ増えてしまうという問題が生じました。メモリリークしちゃったかな?と思ってtracemallockで調べたところ、tensorflow_core/pytho/util/tf_stack.py内のextract_stack()の返り値として作られているlistが原因だと分かりましたが、そんな子は呼んだ覚えがありません。ログか何かを溜め込んでいるんでしょうか?誰か教えて下さい…
EDysDv0UEAAXcXo.png

発症環境

Ubuntu18.04
(cpu:確認して追記します) / TitanV x 1
tensorflow-gpu 1.14.1 / keras 2.2.4

Windows10
intel core i9-9900k / RTX2080Ti x 1
tensorflow-gpu 2.0.0rc0

Windows10
AMD Ryzen7-1700X / GTX1080Ti x 1
tensorflow-gpu 2.0.0rc0

対策

tensorflowさんに要らないlistだけ捨てるようにお願いする方法が見つからなかったので、Sessionごと消しちゃうことにしました。keras.backend.clear_session()を呼べば、いまあるSessionまるごと消え去れます。
https://www.tensorflow.org/api_docs/python/tf/keras/backend/clear_session
https://keras.io/ja/backend/

以下のコードのようにmodelが要らなくなった時点でkeras.backend.clear_session()するとメモリ消費量じわじわ増加病は直ります。

python
import gc
import tensorflow as tf
from tensorflow import keras
physical_devices = tf.config.experimental.list_physical_devices('GPU')
if len(physical_devices) > 0:
    tf.config.experimental.set_memory_growth(physical_devices[0], True)
    print('memory growth:', tf.config.experimental.get_memory_growth(physical_devices[0]))
else:
    print("Not enough GPU hardware devices available")

def study_mnist():
    mnist = keras.datasets.mnist
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0

    inputs = keras.layers.Input(shape=(28, 28))
    x = keras.layers.Flatten()(inputs)
    x = keras.layers.Dense(128, activation='relu')(x)
    x = keras.layers.Dense(128, activation='relu')(x)
    x = keras.layers.Dropout(0.2)(x)
    x = keras.layers.Dense(10, activation=None)(x)
    predictions = keras.layers.Activation('softmax')(x)

    model = keras.models.Model(inputs=inputs, outputs=predictions)

    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    modelCheckpoint = keras.callbacks.ModelCheckpoint(filepath='../tmp.h5',
                                                      monitor='val_loss',
                                                      verbose=1,
                                                      save_best_only=True, )
    hist = model.fit(x_train, y_train, validation_split=0.1, epochs=10, verbose=0,
                     callbacks=[modelCheckpoint])

    model = keras.models.load_model('../tmp.h5')
    print('ModelCheckPointのスコア')
    print(model.evaluate(x_test, y_test, verbose=0))
    del model

    keras.backend.clear_session() # ←これです
    gc.collect()

def main():
    for k in range(100):
        study_mnist()

if __name__ == "__main__":
    main()

keras.backend.clear_session()がないとき

keras.backend.clear_session()があるとき

やったね!

22
Help us understand the problem. What is going on with this article?
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
studio_haneya
製造業でデータサイエンティスト的な仕事をやってます

Comments

No comments
Sign up for free and join this conversation.
Sign Up
If you already have a Qiita account Login
22
Help us understand the problem. What is going on with this article?