やりたいこと
計算グラフを作ったときの関数やソースコードを使わずに、計算グラフの一部を複製したい。
方法
tf.contrib.graph_editor.graph_replace()
を使うと、
入力テンソルと出力テンソルさえ与えられれば途中のグラフを複製できます。
複製元と複製先で変数は共有されます。
公式ドキュメント: tf.contrib.graph_editor.graph_replace
サンプルコードと実行結果
import tensorflow as tf
from tensorflow.contrib import graph_editor
tf.logging.set_verbosity(tf.logging.ERROR)
def f():
"""適当に計算グラフを作成して入力テンソルと出力テンソルを返す"""
with tf.variable_scope('f'):
x = tf.placeholder(tf.float32, (None, 64), name='x')
h = x
h = tf.layers.dense(h, 100, name='dense_1')
h = tf.layers.dense(h, 100, name='dense_2')
y = h
return x, y
with tf.Graph().as_default():
# 計算グラフを作成
x, y = f()
# x と互換性のあるテンソル
new_x = tf.placeholder(x.dtype, x.get_shape(), name='x_copy')
# x の代わりに x_copy を使って、x から y への計算をやりなおす
new_y = graph_editor.graph_replace(
y, # 複製元の出力
{x: new_x}, # {複製元の入力: 新しい入力}
src_scope='f', # 複製元の名前スコープ (省略可)
dst_scope='new_f', # 複製先の名前スコープ (省略可)
)
# TensorBoardで確認
tf.summary.FileWriter('/tmp/tensorboard', graph=tf.get_default_graph()).close()
ちゃんと複製されています。
また、f/dense_1
とnew_f/dense_1
、f/dense_2
とnew_f/dense_2
は重みを共有しています。
補足
src_scope
とdst_scope
x
からy
を計算する上で必要なOpが src_scope/foo/bar
という名前であれば dst_scope/foo/bar
として複製されます。
名前が重複した場合は普通のOpと同じく dst_scope/foo/bar_1
のような名前になります。
入力/出力がtf.Variable
の場合
tf.Variable
はtf.Tensor
ではないので、そのまま引数として渡すとエラーになります。
tf.Variable.value()
で出力テンソルを取得してください。
x = tf.placeholder(name='x', shape=(1, 64), dtype=tf.float32) # Tensor
new_x = tf.get_variable(name='new_x', shape=(1, 64), dtype=tf.float32) # Variale
graph_editor.graph_replace(y, {x: new_x}) # ERROR
graph_editor.graph_replace(y, {x: new_x.value()}) # OK
変数も共有せずに完全に複製する
tf.train.export_meta_graph()
とtf.train.import_meta_graph()
を使えば可能です。
meta_graph.export_scoped_meta_graph()
とmeta_graph.import_scoped_meta_graph()
を使えば可能です。
from tensorflow.python.framework import meta_graph
meta_graph_def, _ = meta_graph.export_scoped_meta_graph(export_scope='f')
meta_graph.import_scoped_meta_graph(meta_graph_def, import_scope='new_f')
ただし、複製元のグラフを作成するときに乱数のシード値を指定していた場合は、複製先も同じシード値を使うことになります。ご注意ください。
tf.train.export_meta_graph()
がSaverを作っているのが理由だったようです。