Python
TensorFlow

TensorFlowで強制的にallow_growth = Trueをする

TensorFlowは allow_growth を設定しないとGPUのメモリを全部使おうとしてしまうけど、色々困るので強制的に設定するようにしてみた。

おことわり

世の中こういう事をやりたい人はそれなりにいると思うけど、「バグ報告のときとかに問題になるからやめとけ 1みたいな話もあったりするのでその辺は認識した上でどうぞ。

やり方

site-packages ディレクトリ (Minicondaとかの場合、インストール先/lib/python3.6/site-packages/ とからへん)直下の sitecustomize.py (無ければ作る) に、以下のようなコードを入れる。

import importlib.machinery
import sys


class CustomFinder(importlib.machinery.PathFinder):

    def find_spec(self, fullname, path=None, target=None):
        if fullname == 'tensorflow.python.client.session':
            spec = super().find_spec(fullname, path, target)
            loader = CustomLoader(fullname, spec.origin)
            return importlib.machinery.ModuleSpec(fullname, loader)
        return None


class CustomLoader(importlib.machinery.SourceFileLoader):

    def exec_module(self, module):
        r = super().exec_module(module)
        self._patch(module)
        return r

    def _patch(self, module):
        original_init = module.Session.__init__

        def custom_init(self, *args, config=None, **kwargs):
            import tensorflow as tf
            if config is None:
                config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
            return original_init(self, *args, config=config, **kwargs)

        module.Session.__init__ = custom_init
        return module


sys.meta_path.insert(0, CustomFinder())

import されたときに __init__ にmonkey patchする感じ。

動作確認

$ python -c 'import tensorflow as tf; print(tf.Session()._config)'
()
gpu_options {
  allow_growth: true
}

  1. …とまでは言ってないけど。