LoginSignup
3
3

More than 5 years have passed since last update.

tf.get_collectionの引数scopeで特定スコープのOperationだけを抽出

Last updated at Posted at 2018-07-16

概要

GANとかだと、複数のネットワークがあって、それぞれにBatchNormalizationがある状況に遭遇する。このとき、一部のBatchNormalizationだけ更新したい。それには、tf.get_collectionをscopeを指定して用いれば良い。

詳細

前提知識

tf.layers.batch_normalizationを使う場合の訓練は、

x_norm = tf.layers.batch_normalization(x, training=training)

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
  train_op = optimizer.minimize(loss)

のようにして訓練時にtf.get_collectionでupdate_opsを取得します。

tf.get_collectionの第2引数にスコープを指定すると特定のスコープのものだけ取得できます。

import tensorflow as tf

with tf.variable_scope("Foo"):
    x = tf.zeros([5, 10])
    tf.layers.batch_normalization(x, training=True)

with tf.variable_scope("Bar"):
    x = tf.zeros([5, 10])
    tf.layers.batch_normalization(x, training=True)

print(tf.get_collection(tf.GraphKeys.UPDATE_OPS, "Foo"))
# [<tf.Operation 'Foo/batch_normalization/AssignMovingAvg' type=AssignSub>, <tf.Operation 'Foo/batch_normalization/AssignMovingAvg_1' type=AssignSub>]


print(tf.get_collection(tf.GraphKeys.UPDATE_OPS, "Bar"))
# [<tf.Operation 'Bar/batch_normalization/AssignMovingAvg' type=AssignSub>, <tf.Operation 'Bar/batch_normalization/AssignMovingAvg_1' type=AssignSub>]

確認環境

python3.6.6
tensorflow1.9.0

3
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
3