OptimizerやModelを入力とするtf.functionが"tried to create variables on non-first call"で死ぬ
最初のモデルコール(build)以外で変数を宣言したつもりなんてなかったとおもっていたら思わぬトラップにかかったので備忘録
前提
TensorFlow2(筆者は2.5 nightlyで確認)
tf.function内でのtf.Variableの作成は禁止
- tf.functionは基本的に一度確保したら同じメモリ領域で動作する。
- tf.Variableは新しいメモリ領域を確保して変数を生成する。
- したがって毎回新しいVariableを生成すると毎回新しいメモリを食ってしまうのでエラーがおきる。
@tf.function
def f(x):
v = tf.Variable(1.0)
v.assign_add(x)
return v
with assert_raises(ValueError):
f(1.0)
Caught expected exception
<class 'ValueError'>: in user code:
<ipython-input-17-73e410646579>:3 f *
v = tf.Variable(1.0)
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:262 __call__ **
return cls._variable_v2_call(*args, **kwargs)
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:256 _variable_v2_call
shape=shape)
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:67 getter
return captured_getter(captured_previous, **kwargs)
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py:702 invalid_creator_scope
"tf.function-decorated function tried to create "
単なるpython変数の宣言ならよい。関数を呼び出すたびに同じメモリ領域をリサイクルして使われるので。
@tf.function
def f(x):
v = tf.ones((5, 5), dtype=tf.float32)
return v
# raises no error!
本題
tf.keras.models.Modelを入力としたtf.functionは思わぬエラーを引き起こす。
@tf.function
def call(model1: tf.keras.models.Model, inputs: tf.Tensor):
return model1(inputs)
if __name__ == '__main__':
model1 = tf.keras.Sequential([
tf.keras.layers.Dense(16),
tf.keras.layers.Dense(4)
])
model2 = tf.keras.Sequential([
tf.keras.layers.Dense(16),
tf.keras.layers.Dense(4)
])
inputs = tf.ones((10, 10), dtype=tf.float32)
call(model1, inputs) # raises no error!
call(model2, inputs) # raises an error! "tf.function-decorated function tried to create variables on non-first call"
ポイントはkerasのモデルは一番最初にコールされたときに重み行列などのVariableをbuildすること。
よってこのcall関数にビルドされていない2つのモデルを引数として渡すと2度めの呼び出し時に2つめのモデルの重みをビルドしようとして事故る!!
解決策
# @tf.functionは直接つけない
def call(model1: tf.keras.models.Model, inputs: tf.Tensor):
return model1(inputs)
if __name__ == '__main__':
model1 = tf.keras.Sequential([
tf.keras.layers.Dense(16),
tf.keras.layers.Dense(4)
])
model2 = tf.keras.Sequential([
tf.keras.layers.Dense(16),
tf.keras.layers.Dense(4)
])
inputs = tf.ones((10, 10), dtype=tf.float32)
# pythonでは関数もオブジェクト。それぞれのモデル専用の関数を生成する。
model1_call = tf.function(call)
model2_call = tf.function(call)
model1_call(model1, inputs)
model2_call(model2, inputs)
tf.functionは関数の定義に直接書くのをやめて各モデル専用の関数オブジェクトを生成する。デコレータとか高階関数とかってメタプログラミング的な概念はpython初心者には難しいので理由を深追いしないのが大事。
これでいける理由を知りたい人はpython デコレータ 高階関数などでぐぐるとよいとおもわれる。
まとめ
Eager ExecutionはGPU上で複雑かつstep数の多いアルゴリズムを書く場合はおそすぎて話にならないので計算効率を追求するためにtf.functionを使いたいところだがまだTF2がリリースされて間もないことやtf.functionがブラックボックスすぎることが原因で意味不明なバグに躓きがち。
今後もtf2まわりやdeeplearningまわりの知見をたくさん蓄積していくのでLGTMとフォローをお願いします!