kerasに一晩学習させてるとこんな感じになっちゃうのを対策しました
症状
kerasで繰り返し学習していたところメインメモリの使用量がじわじわ増えてしまうという問題が生じました。メモリリークしちゃったかな?と思ってtracemallockで調べたところ、tensorflow_core/pytho/util/tf_stack.py内のextract_stack()の返り値として作られているlistが原因だと分かりましたが、そんな子は呼んだ覚えがありません。ログか何かを溜め込んでいるんでしょうか?誰か教えて下さい…
発症環境
Ubuntu18.04
(cpu:確認して追記します) / TitanV x 1
tensorflow-gpu 1.14.1 / keras 2.2.4
Windows10
intel core i9-9900k / RTX2080Ti x 1
tensorflow-gpu 2.0.0rc0
Windows10
AMD Ryzen7-1700X / GTX1080Ti x 1
tensorflow-gpu 2.0.0rc0
対策
tensorflowさんに要らないlistだけ捨てるようにお願いする方法が見つからなかったので、Sessionごと消しちゃうことにしました。keras.backend.clear_session()を呼べば、いまあるSessionまるごと消え去れます。
https://www.tensorflow.org/api_docs/python/tf/keras/backend/clear_session
https://keras.io/ja/backend/
以下のコードのようにmodelが要らなくなった時点でkeras.backend.clear_session()するとメモリ消費量じわじわ増加病は直ります。
import gc
import tensorflow as tf
from tensorflow import keras
physical_devices = tf.config.experimental.list_physical_devices('GPU')
if len(physical_devices) > 0:
tf.config.experimental.set_memory_growth(physical_devices[0], True)
print('memory growth:', tf.config.experimental.get_memory_growth(physical_devices[0]))
else:
print("Not enough GPU hardware devices available")
def study_mnist():
mnist = keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
inputs = keras.layers.Input(shape=(28, 28))
x = keras.layers.Flatten()(inputs)
x = keras.layers.Dense(128, activation='relu')(x)
x = keras.layers.Dense(128, activation='relu')(x)
x = keras.layers.Dropout(0.2)(x)
x = keras.layers.Dense(10, activation=None)(x)
predictions = keras.layers.Activation('softmax')(x)
model = keras.models.Model(inputs=inputs, outputs=predictions)
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
modelCheckpoint = keras.callbacks.ModelCheckpoint(filepath='../tmp.h5',
monitor='val_loss',
verbose=1,
save_best_only=True, )
hist = model.fit(x_train, y_train, validation_split=0.1, epochs=10, verbose=0,
callbacks=[modelCheckpoint])
model = keras.models.load_model('../tmp.h5')
print('ModelCheckPointのスコア')
print(model.evaluate(x_test, y_test, verbose=0))
del model
keras.backend.clear_session() # ←これです
gc.collect()
def main():
for k in range(100):
study_mnist()
if __name__ == "__main__":
main()
keras.backend.clear_session()がないとき
keras.backend.clear_session()があるとき
やったね!