概要
GANとかだと、複数のネットワークがあって、それぞれにBatchNormalizationがある状況に遭遇する。このとき、一部のBatchNormalizationだけ更新したい。それには、tf.get_collectionをscopeを指定して用いれば良い。
詳細
前提知識
tf.layers.batch_normalizationを使う場合の訓練は、
x_norm = tf.layers.batch_normalization(x, training=training)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
train_op = optimizer.minimize(loss)
のようにして訓練時にtf.get_collectionでupdate_opsを取得します。
例
tf.get_collectionの第2引数にスコープを指定すると特定のスコープのものだけ取得できます。
import tensorflow as tf
with tf.variable_scope("Foo"):
x = tf.zeros([5, 10])
tf.layers.batch_normalization(x, training=True)
with tf.variable_scope("Bar"):
x = tf.zeros([5, 10])
tf.layers.batch_normalization(x, training=True)
print(tf.get_collection(tf.GraphKeys.UPDATE_OPS, "Foo"))
# [<tf.Operation 'Foo/batch_normalization/AssignMovingAvg' type=AssignSub>, <tf.Operation 'Foo/batch_normalization/AssignMovingAvg_1' type=AssignSub>]
print(tf.get_collection(tf.GraphKeys.UPDATE_OPS, "Bar"))
# [<tf.Operation 'Bar/batch_normalization/AssignMovingAvg' type=AssignSub>, <tf.Operation 'Bar/batch_normalization/AssignMovingAvg_1' type=AssignSub>]
確認環境
python3.6.6
tensorflow1.9.0