3
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

TL;DR

以下では、僕が遭遇したtf2の厄介な挙動(主にマルチプロセスでの並列化まわり)を4つ紹介していきます。
本当は10選くらい書こうと思ってましたが力尽きて4つになりました

python本体が非常に遅いのもあって、モデルを並列に実行したくなる機会も多いと思います。
そんなとき、tensorflow2ではまりがちな罠をまとめました。

素直にpytorchで実装したほうが楽ですけどね...(公式のtorch.multiprocessingもあるし)

罠1: モデルのパラメータが作られるのは初回に計算を実行したとき

tensorflow2では、モデル・レイヤーのcallメソッドが呼ばれたタイミング(= 順伝播の計算をしたとき)でモデルの重みが初期化されます。

# 例1
import tensorflow as tf
import numpy as np
from tensorflow.keras import layers

dense1 = layers.Dense(10)
print(dense1.traninable_varialbles) # まだ重みが作られていない
# []

data = np.zeros([1, 10])
dense1(dummy) # ここで全結合層の重みが作られる
print(dense1.traninable_varialbles) 
# [<tf.Variable 'dense/kernel:0' shape=(10, 10)...]

上の例では全結合層のインスタンスを作ったタイミングでは全結合層のパラメータは作られておらず、
dataに対して計算を行った後で始めて作られたことが分かります。

これは複数のモデルの重みを学習に同期するときに問題となります。

# 例2
class Model(tf.keras.Model):
    def __init__(self, ....):
        ....

model1 = Model(...)
model2 = Model(...)

weight = model1.get_weights() # モデル1のパラメータはまだ作られていないので, weight = []になる
model2.set_weights(weight)

例2のようにモデルのパラメータを初期化することができないのです。
これを解決するためには、モデルのパラメータを初期化するためのダミーのデータをモデルに流してやる必要があります。(例3)

# 例3
class Model(tf.keras.Model):
    def __init__(self, ....):
        ....

model1 = Model(...)
model2 = Model(...)

# ダミーの入力を順伝播させて、モデルの重みを初期化する
dummy_data = np.zeros([...])
model1(dummy_data)
model2(dummy_data)

weight = model1.get_weights() # weight = [<tf.Variable ...]
model2.set_weights(weight)

罠2: マルチプロセスで走らせると死ぬ~その1~

tf2のissueでも報告されているとおり、
tensorflow2のモデルをマルチプロセスで走らせる時、突然プロセスが死んだりすることがあります。
(運良く?死なないこともある)

これは、tensorflow2がモデルの推論、保存、ロード時に自動的に作るスレッドが上手くmultiprocessingで作った子プロセスと胸像んできないため起こるバグです。
以下のように自動的にスレッドを作らないように設定すればこれは解決できます。
tf.config.threading.set_intra_op_parallelism_threads(1)

罠3: マルチプロセスで走らせると死ぬ~その2~(windows限定?)

Tensorflow2をマルチプロセスで走らせると様々な理由で動きません。
これはそのうちの一つ、windows上で起こるバグです。

例4↓のようなスクリプトを用意してやって、

# 例4
import tensorflow as tf
import multiprocessing
class Logger():
    """tf.SummaryWriterのインスタンスを持つ"""
    def __init__(self, logdir='./'):
        self.writer = tf.summary.create_file_writer(logdir)

class Runner:
    def __init__(self, logger):
        self.logger = logger

    @classmethod
    def start_test(cls, logger):
        runner = cls(logger=logger)

if __name__ == '__main__':
    logger = Logger() # 親プロセスでtf.SummaryWriterのインスタンスを作る
    p = multiprocessing.Process(target=Runner.start_test, kwargs=dict(logger=logger)) # 子プロセスにSummaryWriterのインスタンスを渡す
    p.start()
           

シェル上で実行すると、

$$ python test.py

以下のエラーを吐きます。

Traceback (most recent call last):
  File "test.py", line 20, in <module>
    p.start()
  File "C:\tools\Anaconda3\envs\tf2\lib\multiprocessing\process.py", line 112, in start
    self._popen = self._Popen(self)
  File "C:\tools\Anaconda3\envs\tf2\lib\multiprocessing\context.py", line 223, in _Popen
    return _default_context.get_context().Process._Popen(process_obj)
  File "C:\tools\Anaconda3\envs\tf2\lib\multiprocessing\context.py", line 322, in _Popen
    return Popen(process_obj)
  File "C:\tools\Anaconda3\envs\tf2\lib\multiprocessing\popen_spawn_win32.py", line 89, in __init__
    reduction.dump(process_obj, to_child)
  File "C:\tools\Anaconda3\envs\tf2\lib\multiprocessing\reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
  File "C:\tools\Anaconda3\envs\tf2\lib\site-packages\tensorflow_core\python\framework\ops.py", line 884, in __reduce__
    return convert_to_tensor, (self._numpy(),)
ValueError: Cannot convert a Tensor of dtype resource to a NumPy array.

エラーが出てしまったのは、tf.SummaryWriterのインスタンスを上手く親プロセスから子プロセスへコピーできなかったためです。
SummaryWriterに限らず、tensorflow由来のオブジェクトはこのようにプロセス間での受け渡しをできないものがあるので、なるべく子プロセスでインスタンスを作ったほうがバグが出にくいです。
OSごとにpythonの標準ライブラリmultiprocessingの実装が異なっているため、OSによってこのバグが出たりでなかったりします。

罠4: マルチプロセスで走らせると死ぬ~その3~(GPU使用時)

マルチプロセスでGPUを使う場合にもtensorflowは死にます。やったね!
具体的には、子・親プロセスの両方でtensorflowをインポートすると、CUDAの初期化エラーが出ます。


# 例5
import tensorflow as tf # 親プロセスでtensorflowをインポート
import multiprocessing

class Runner:
    @classmethod
    def start_test(cls, logger):
        import tensorflow

if __name__ == '__main__':
    p = multiprocessing.Process(target=Runner.start_test)
    p.start() # 子プロセスでtensorflowがインポートされ、エラー
           

これをシェルで走らせると、CUDA_ERROR_NOT_INITIALIZEDというエラーでプロセスが死にます。
親プロセスでtensorflowをインポートするのを避け、子プロセスだけでインポートすればこの問題は回避できます。
tensorflowのオブジェクトではなく、それを作るファクトリメソッドを子プロセスに渡す、などの工夫が必要となります。

結果的に読みにくいプログラムになっちゃいますけどね!!!

3
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?