Tensorflowで乱数要素がないのに、実行するたびに結果が変わるグラフを意図せず書いてしまい、デバッグに時間を要したので原因と対策を記載する。
毎回動作が変わるグラフの例はこちら。
a = tf.Variable(1, name="a")
b = tf.Variable(10, name="b")
place_a = tf.assign(a, b)
place_b = tf.assign(b, a)
with tf.Session() as sess:
sess.run(tf.variables_initializer(tf.global_variables()))
print (sess.run([place_a, place_b]))
a,b = 10, 1 と返ってくることを期待してしまうが、実際は[1,1],[10,10]のいずれかの値が返ってくる。
これは、aを更新する計算グラフとbを更新する計算グラフが非同期に(おそらく別スレッドで)実行されているためである。
対策としては、中間変数を作成し、control_dependenciesでa,bが更新される前に確実に中間変数が計算されるようにする。
a = tf.Variable(1, name="a")
b = tf.Variable(10, name="b")
new_a = b + 0
new_b = a + 0
with tf.control_dependencies([new_a, new_b]):
place_a = tf.assign(a, new_a)
place_b = tf.assign(b, new_b)
with tf.Session() as sess:
sess.run(tf.variables_initializer(tf.global_variables()))
print (sess.run([place_a, place_b]))
これで確実にa,b = 10, 1となる。
なお、control_dependenciesを削除すると非決定的な動作になり、うまくいったりいかなかったりする。
new_a = tf.identity(b)としてもうまくいかない。(理由はtf.identityに関する議論を参照)
これぐらい露骨な例まで落とすと何が悪いか気付けるが、複雑なグラフ中だと見落とすリスクがある。
tf.assignの利用は常に注意が必要である。