LoginSignup
3
0

More than 5 years have passed since last update.

TensorFlowで強制的にallow_growth = Trueをする

Last updated at Posted at 2017-12-28

TensorFlowは allow_growth を設定しないとGPUのメモリを全部使おうとしてしまうけど、色々困るので強制的に設定するようにしてみた。

おことわり

世の中こういう事をやりたい人はそれなりにいると思うけど、「バグ報告のときとかに問題になるからやめとけ 1みたいな話もあったりするのでその辺は認識した上でどうぞ。

やり方

site-packages ディレクトリ (Minicondaとかの場合、インストール先/lib/python3.6/site-packages/ とからへん)直下の sitecustomize.py (無ければ作る) に、以下のようなコードを入れる。

import importlib.machinery
import sys


class CustomFinder(importlib.machinery.PathFinder):

    def find_spec(self, fullname, path=None, target=None):
        if fullname == 'tensorflow.python.client.session':
            spec = super().find_spec(fullname, path, target)
            loader = CustomLoader(fullname, spec.origin)
            return importlib.machinery.ModuleSpec(fullname, loader)
        return None


class CustomLoader(importlib.machinery.SourceFileLoader):

    def exec_module(self, module):
        r = super().exec_module(module)
        self._patch(module)
        return r

    def _patch(self, module):
        original_init = module.Session.__init__

        def custom_init(self, *args, config=None, **kwargs):
            import tensorflow as tf
            if config is None:
                config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
            return original_init(self, *args, config=config, **kwargs)

        module.Session.__init__ = custom_init
        return module


sys.meta_path.insert(0, CustomFinder())

import されたときに __init__ にmonkey patchする感じ。

動作確認

$ python -c 'import tensorflow as tf; print(tf.Session()._config)'
()
gpu_options {
  allow_growth: true
}

  1. …とまでは言ってないけど。 

3
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
0