解説
tensorflowは普通に使うとGPUメモリを取れるだけ取りにいっちゃいます。大きいモデルを1つだけ学習するならそれでも良いんですが、小さいモデルで学習する場合もガッツリメモリをもっていかれるのがイマイチです。例えばmnistを隠れ層128ノード1層で学習してもこんな感じになっちゃいます。
こんなに必要なわけないので、最小限だけ使うように設定します。
tensorflow1系
tensorflow1系を使ってた頃はkerasとtensorflowが別packageだったので、keras.backend.set_session()の引数にtf.Session(config=hoge)を渡すことでコントロールしてました。
from keras import backend as K
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
K.set_session(sess)
tensorflow2.0.0alpha
kerasがtensorflowのモジュールになったおかげで2.0.0alphaでは1行で書けるようになりました。
#メモリ制限(growth)
import tensorflow as tf
tf.config.gpu.set_per_process_memory_growth(True)
tensorflow2.0.0beta~2.0.0
tensorflow2.0.0beta以降でtf.config.gpuがなくなり、tf.config.experimental.set_memory_growth()で設定するようになりました。tf.experimental.config.list_physical_devices('GPU')でGPUデバイスのリストを取得し、memory_growthを設定したいデバイスすべてに対し実行してやります。
physical_devices = tf.config.experimental.list_physical_devices('GPU')
if len(physical_devices) > 0:
for device in physical_devices:
tf.config.experimental.set_memory_growth(device, True)
print('{} memory growth: {}'.format(device, tf.config.experimental.get_memory_growth(device)))
else:
print("Not enough GPU hardware devices available")
tensorflow2.1.0~2.3.0 (2020.9.15 修正)
tf.config.experimental.list_physical_devices()のexperimentalが取れました。それ以外は上記と同じです。tensorflow2.3.0の時点でまだdeprecatedにもなっていなくてexperimentalが入っていても問題なく動きますので旧いバージョンの入ったマシンと共用で使えるようにしたい場合は上記のコードのままでも大丈夫です。
physical_devices = tf.config.list_physical_devices('GPU')
if len(physical_devices) > 0:
for device in physical_devices:
tf.config.experimental.set_memory_growth(device, True)
print('{} memory growth: {}'.format(device, tf.config.experimental.get_memory_growth(device)))
else:
print("Not enough GPU hardware devices available")
以前書いていた間違いについて (2020.9.15修正)
すみません、2020.8.5に書いた追記でtensorflow2.3.0ではset_memory_growthが要らないようだと書いていたんですが間違いでした。これまでと同様、set_memory_growthが必要です。書き方は同じで出来ます。
更新履歴
2019.8.13 tensorflow2.0.0beta用のメモリ制限について追記しました
2019.10.21 tensorflow2.0 stableでも動作を確認しました
2020.2.1 tensorflow2.1.0での動作を確認しました
2020.9.15 tensorflow2.3.0での動作を確認しました