TL;DR
以下では、僕が遭遇したtf2の厄介な挙動(主にマルチプロセスでの並列化まわり)を4つ紹介していきます。
本当は10選くらい書こうと思ってましたが力尽きて4つになりました
python本体が非常に遅いのもあって、モデルを並列に実行したくなる機会も多いと思います。
そんなとき、tensorflow2ではまりがちな罠をまとめました。
素直にpytorchで実装したほうが楽ですけどね...(公式のtorch.multiprocessing
もあるし)
罠1: モデルのパラメータが作られるのは初回に計算を実行したとき
tensorflow2では、モデル・レイヤーのcallメソッドが呼ばれたタイミング(= 順伝播の計算をしたとき)でモデルの重みが初期化されます。
# 例1
import tensorflow as tf
import numpy as np
from tensorflow.keras import layers
dense1 = layers.Dense(10)
print(dense1.traninable_varialbles) # まだ重みが作られていない
# []
data = np.zeros([1, 10])
dense1(dummy) # ここで全結合層の重みが作られる
print(dense1.traninable_varialbles)
# [<tf.Variable 'dense/kernel:0' shape=(10, 10)...]
上の例では全結合層のインスタンスを作ったタイミングでは全結合層のパラメータは作られておらず、
data
に対して計算を行った後で始めて作られたことが分かります。
これは複数のモデルの重みを学習前に同期するときに問題となります。
# 例2
class Model(tf.keras.Model):
def __init__(self, ....):
....
model1 = Model(...)
model2 = Model(...)
weight = model1.get_weights() # モデル1のパラメータはまだ作られていないので, weight = []になる
model2.set_weights(weight)
例2のようにモデルのパラメータを初期化することができないのです。
これを解決するためには、モデルのパラメータを初期化するためのダミーのデータをモデルに流してやる必要があります。(例3)
# 例3
class Model(tf.keras.Model):
def __init__(self, ....):
....
model1 = Model(...)
model2 = Model(...)
# ダミーの入力を順伝播させて、モデルの重みを初期化する
dummy_data = np.zeros([...])
model1(dummy_data)
model2(dummy_data)
weight = model1.get_weights() # weight = [<tf.Variable ...]
model2.set_weights(weight)
罠2: マルチプロセスで走らせると死ぬ~その1~
tf2のissueでも報告されているとおり、
tensorflow2のモデルをマルチプロセスで走らせる時、突然プロセスが死んだりすることがあります。
(運良く?死なないこともある)
これは、tensorflow2がモデルの推論、保存、ロード時に自動的に作るスレッドが上手くmultiprocessingで作った子プロセスと胸像んできないため起こるバグです。
以下のように自動的にスレッドを作らないように設定すればこれは解決できます。
tf.config.threading.set_intra_op_parallelism_threads(1)
罠3: マルチプロセスで走らせると死ぬ~その2~(windows限定?)
Tensorflow2をマルチプロセスで走らせると様々な理由で動きません。
これはそのうちの一つ、windows上で起こるバグです。
例4↓のようなスクリプトを用意してやって、
# 例4
import tensorflow as tf
import multiprocessing
class Logger():
"""tf.SummaryWriterのインスタンスを持つ"""
def __init__(self, logdir='./'):
self.writer = tf.summary.create_file_writer(logdir)
class Runner:
def __init__(self, logger):
self.logger = logger
@classmethod
def start_test(cls, logger):
runner = cls(logger=logger)
if __name__ == '__main__':
logger = Logger() # 親プロセスでtf.SummaryWriterのインスタンスを作る
p = multiprocessing.Process(target=Runner.start_test, kwargs=dict(logger=logger)) # 子プロセスにSummaryWriterのインスタンスを渡す
p.start()
シェル上で実行すると、
$$ python test.py
以下のエラーを吐きます。
Traceback (most recent call last):
File "test.py", line 20, in <module>
p.start()
File "C:\tools\Anaconda3\envs\tf2\lib\multiprocessing\process.py", line 112, in start
self._popen = self._Popen(self)
File "C:\tools\Anaconda3\envs\tf2\lib\multiprocessing\context.py", line 223, in _Popen
return _default_context.get_context().Process._Popen(process_obj)
File "C:\tools\Anaconda3\envs\tf2\lib\multiprocessing\context.py", line 322, in _Popen
return Popen(process_obj)
File "C:\tools\Anaconda3\envs\tf2\lib\multiprocessing\popen_spawn_win32.py", line 89, in __init__
reduction.dump(process_obj, to_child)
File "C:\tools\Anaconda3\envs\tf2\lib\multiprocessing\reduction.py", line 60, in dump
ForkingPickler(file, protocol).dump(obj)
File "C:\tools\Anaconda3\envs\tf2\lib\site-packages\tensorflow_core\python\framework\ops.py", line 884, in __reduce__
return convert_to_tensor, (self._numpy(),)
ValueError: Cannot convert a Tensor of dtype resource to a NumPy array.
エラーが出てしまったのは、tf.SummaryWriter
のインスタンスを上手く親プロセスから子プロセスへコピーできなかったためです。
SummaryWriterに限らず、tensorflow由来のオブジェクトはこのようにプロセス間での受け渡しをできないものがあるので、なるべく子プロセスでインスタンスを作ったほうがバグが出にくいです。
OSごとにpythonの標準ライブラリmultiprocessing
の実装が異なっているため、OSによってこのバグが出たりでなかったりします。
罠4: マルチプロセスで走らせると死ぬ~その3~(GPU使用時)
マルチプロセスでGPUを使う場合にもtensorflowは死にます。やったね!
具体的には、子・親プロセスの両方でtensorflowをインポートすると、CUDAの初期化エラーが出ます。
# 例5
import tensorflow as tf # 親プロセスでtensorflowをインポート
import multiprocessing
class Runner:
@classmethod
def start_test(cls, logger):
import tensorflow
if __name__ == '__main__':
p = multiprocessing.Process(target=Runner.start_test)
p.start() # 子プロセスでtensorflowがインポートされ、エラー
これをシェルで走らせると、CUDA_ERROR_NOT_INITIALIZED
というエラーでプロセスが死にます。
親プロセスでtensorflowをインポートするのを避け、子プロセスだけでインポートすればこの問題は回避できます。
tensorflowのオブジェクトではなく、それを作るファクトリメソッドを子プロセスに渡す、などの工夫が必要となります。
結果的に読みにくいプログラムになっちゃいますけどね!!!