Help us understand the problem. What is going on with this article?

【TensorFlow】tf.contrib.graph_editorを使って計算グラフの一部を複製する

More than 1 year has passed since last update.

やりたいこと

計算グラフを作ったときの関数やソースコードを使わずに、計算グラフの一部を複製したい。

方法

tf.contrib.graph_editor.graph_replace()を使うと、
入力テンソルと出力テンソルさえ与えられれば途中のグラフを複製できます。
複製元と複製先で変数は共有されます。

公式ドキュメント: tf.contrib.graph_editor.graph_replace

サンプルコードと実行結果

import tensorflow as tf
from tensorflow.contrib import graph_editor

tf.logging.set_verbosity(tf.logging.ERROR)


def f():
    """適当に計算グラフを作成して入力テンソルと出力テンソルを返す"""
    with tf.variable_scope('f'):
        x = tf.placeholder(tf.float32, (None, 64), name='x')
        h = x
        h = tf.layers.dense(h, 100, name='dense_1')
        h = tf.layers.dense(h, 100, name='dense_2')
        y = h
        return x, y


with tf.Graph().as_default():
    # 計算グラフを作成
    x, y = f()

    # x と互換性のあるテンソル
    new_x = tf.placeholder(x.dtype, x.get_shape(), name='x_copy')

    # x の代わりに x_copy を使って、x から y への計算をやりなおす
    new_y = graph_editor.graph_replace(
        y,                  # 複製元の出力
        {x: new_x},         # {複製元の入力: 新しい入力}
        src_scope='f',      # 複製元の名前スコープ (省略可)
        dst_scope='new_f',  # 複製先の名前スコープ (省略可)
    )

    # TensorBoardで確認
    tf.summary.FileWriter('/tmp/tensorboard', graph=tf.get_default_graph()).close()

graph_large_attrs_key=_too_large_attrs&limit_attr_size=1024&run=(3).png

ちゃんと複製されています。
また、f/dense_1new_f/dense_1f/dense_2new_f/dense_2は重みを共有しています。

補足

src_scopedst_scope

xからyを計算する上で必要なOpが src_scope/foo/bar という名前であれば dst_scope/foo/bar として複製されます。
名前が重複した場合は普通のOpと同じく dst_scope/foo/bar_1 のような名前になります。

入力/出力がtf.Variableの場合

tf.Variabletf.Tensorではないので、そのまま引数として渡すとエラーになります。
tf.Variable.value()で出力テンソルを取得してください。

x = tf.placeholder(name='x', shape=(1, 64), dtype=tf.float32)           # Tensor
new_x = tf.get_variable(name='new_x', shape=(1, 64), dtype=tf.float32)  # Variale

graph_editor.graph_replace(y, {x: new_x})          # ERROR
graph_editor.graph_replace(y, {x: new_x.value()})  # OK

変数も共有せずに完全に複製する

tf.train.export_meta_graph()tf.train.import_meta_graph()を使えば可能です。
meta_graph.export_scoped_meta_graph()meta_graph.import_scoped_meta_graph()を使えば可能です。

from tensorflow.python.framework import meta_graph
meta_graph_def, _ = meta_graph.export_scoped_meta_graph(export_scope='f')
meta_graph.import_scoped_meta_graph(meta_graph_def, import_scope='new_f')

ただし、複製元のグラフを作成するときに乱数のシード値を指定していた場合は、複製先も同じシード値を使うことになります。ご注意ください。
tf.train.export_meta_graph()がSaverを作っているのが理由だったようです。

Why do not you register as a user and use Qiita more conveniently?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away