100
70

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

tensorflow2.0 + kerasでGPUメモリの使用量を抑える方法

Last updated at Posted at 2019-06-03

解説

tensorflowは普通に使うとGPUメモリを取れるだけ取りにいっちゃいます。大きいモデルを1つだけ学習するならそれでも良いんですが、小さいモデルで学習する場合もガッツリメモリをもっていかれるのがイマイチです。例えばmnistを隠れ層128ノード1層で学習してもこんな感じになっちゃいます。
image.png

こんなに必要なわけないので、最小限だけ使うように設定します。

tensorflow1系

tensorflow1系を使ってた頃はkerasとtensorflowが別packageだったので、keras.backend.set_session()の引数にtf.Session(config=hoge)を渡すことでコントロールしてました。

python
from keras import backend as K
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
K.set_session(sess)

tensorflow2.0.0alpha

kerasがtensorflowのモジュールになったおかげで2.0.0alphaでは1行で書けるようになりました。

#メモリ制限(growth)
import tensorflow as tf
tf.config.gpu.set_per_process_memory_growth(True)

image.png

tensorflow2.0.0beta~2.0.0

tensorflow2.0.0beta以降でtf.config.gpuがなくなり、tf.config.experimental.set_memory_growth()で設定するようになりました。tf.experimental.config.list_physical_devices('GPU')でGPUデバイスのリストを取得し、memory_growthを設定したいデバイスすべてに対し実行してやります。

python
physical_devices = tf.config.experimental.list_physical_devices('GPU')
if len(physical_devices) > 0:
    for device in physical_devices:
        tf.config.experimental.set_memory_growth(device, True)
        print('{} memory growth: {}'.format(device, tf.config.experimental.get_memory_growth(device)))
else:
    print("Not enough GPU hardware devices available")

tensorflow2.1.0~2.3.0 (2020.9.15 修正)

tf.config.experimental.list_physical_devices()のexperimentalが取れました。それ以外は上記と同じです。tensorflow2.3.0の時点でまだdeprecatedにもなっていなくてexperimentalが入っていても問題なく動きますので旧いバージョンの入ったマシンと共用で使えるようにしたい場合は上記のコードのままでも大丈夫です。

python
physical_devices = tf.config.list_physical_devices('GPU')
if len(physical_devices) > 0:
    for device in physical_devices:
        tf.config.experimental.set_memory_growth(device, True)
        print('{} memory growth: {}'.format(device, tf.config.experimental.get_memory_growth(device)))
else:
    print("Not enough GPU hardware devices available")

以前書いていた間違いについて (2020.9.15修正)

すみません、2020.8.5に書いた追記でtensorflow2.3.0ではset_memory_growthが要らないようだと書いていたんですが間違いでした。これまでと同様、set_memory_growthが必要です。書き方は同じで出来ます。

更新履歴

2019.8.13 tensorflow2.0.0beta用のメモリ制限について追記しました
2019.10.21 tensorflow2.0 stableでも動作を確認しました
2020.2.1 tensorflow2.1.0での動作を確認しました
2020.9.15 tensorflow2.3.0での動作を確認しました

100
70
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
100
70

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?