環境
ubuntu 16.04
tensorflow 1.3.0
python 3.5
tensorboardのtf.summary.merge_all
tensorboardを使うとき、複数のsummaryのTensorを実行用にまとめるため、よく下記のようにmerge_allを使う。
# これらは一例
x = tf.constant(3)
y = tf.constant(3)
image = tf.ones([1, 10, 10, 1])
#
# summaryのTensor作成
tf.summary.scalar("x", x)
tf.summary.scalar("y", y)
tf.summary.image("image", image)
merged = tf.summary.merge_all()
print(tf.get_default_graph().get_collection(tf.GraphKeys.SUMMARIES))
# => [<tf.Tensor 'x:0' shape=() dtype=string>, <tf.Tensor 'y:0' shape=() dtype=string>, <tf.Tensor 'image:0' shape=() dtype=string>]
tf.summary.merge_all に含めたくないTensorがあるとき
しかし、上記でimageだけmergedに含めたくない時はどうするか。
例えば画像情報の保存は負荷が高いので、訓練中には実行したくなかったりする。
summaryのTensor数が少ないうちは個別実行でもよいが、数十とかなってくると厳しい。
途中でmerge_allして、後から含めたくないものだけ定義する方法もあるが、あまりきれいではない気がする。
解決策
merge_allの実際の動作は、tf.GraphKeys.SUMMARIESに含まれているTensorをまとめて一つのTensorにするというもの。(引数でどのGraphKeysにするか指定もできる。)
そこで、下記のように、summary定義の際にcollections=[]を指定してみる。
# summaryのTensor作成
tf.summary.scalar("x", x)
tf.summary.scalar("y", y)
image_summary = tf.summary.image("image", image, collections=[]) # 追加先のcollectionをなしにする
merged = tf.summary.merge_all()
print(tf.get_default_graph().get_collection(tf.GraphKeys.SUMMARIES))
# => [<tf.Tensor 'x:0' shape=() dtype=string>, <tf.Tensor 'y:0' shape=() dtype=string>]
このようにすると、collections=[]を指定したTensorだけ、tf.GraphKeys.SUMMARIESに追加されないため、merge_allの対象からも外すことができる。
imageを保存したいタイミングで、別途image_summaryをrunしてあげればよい。
(・・・これもきれいな方法ではない気はするし、正直GraphKeysの理解も乏しいのでもっといい方法がありそうな気はする。。)