LoginSignup
1
4

More than 5 years have passed since last update.

【TensorFlow】tf.contrib.graph_editorを使って計算グラフの一部を複製する

Last updated at Posted at 2017-09-30

やりたいこと

計算グラフを作ったときの関数やソースコードを使わずに、計算グラフの一部を複製したい。

方法

tf.contrib.graph_editor.graph_replace()を使うと、
入力テンソルと出力テンソルさえ与えられれば途中のグラフを複製できます。
複製元と複製先で変数は共有されます。

公式ドキュメント: tf.contrib.graph_editor.graph_replace

サンプルコードと実行結果

import tensorflow as tf
from tensorflow.contrib import graph_editor

tf.logging.set_verbosity(tf.logging.ERROR)


def f():
    """適当に計算グラフを作成して入力テンソルと出力テンソルを返す"""
    with tf.variable_scope('f'):
        x = tf.placeholder(tf.float32, (None, 64), name='x')
        h = x
        h = tf.layers.dense(h, 100, name='dense_1')
        h = tf.layers.dense(h, 100, name='dense_2')
        y = h
        return x, y


with tf.Graph().as_default():
    # 計算グラフを作成
    x, y = f()

    # x と互換性のあるテンソル
    new_x = tf.placeholder(x.dtype, x.get_shape(), name='x_copy')

    # x の代わりに x_copy を使って、x から y への計算をやりなおす
    new_y = graph_editor.graph_replace(
        y,                  # 複製元の出力
        {x: new_x},         # {複製元の入力: 新しい入力}
        src_scope='f',      # 複製元の名前スコープ (省略可)
        dst_scope='new_f',  # 複製先の名前スコープ (省略可)
    )

    # TensorBoardで確認
    tf.summary.FileWriter('/tmp/tensorboard', graph=tf.get_default_graph()).close()

graph_large_attrs_key=_too_large_attrs&limit_attr_size=1024&run=(3).png

ちゃんと複製されています。
また、f/dense_1new_f/dense_1f/dense_2new_f/dense_2は重みを共有しています。

補足

src_scopedst_scope

xからyを計算する上で必要なOpが src_scope/foo/bar という名前であれば dst_scope/foo/bar として複製されます。
名前が重複した場合は普通のOpと同じく dst_scope/foo/bar_1 のような名前になります。

入力/出力がtf.Variableの場合

tf.Variabletf.Tensorではないので、そのまま引数として渡すとエラーになります。
tf.Variable.value()で出力テンソルを取得してください。

x = tf.placeholder(name='x', shape=(1, 64), dtype=tf.float32)           # Tensor
new_x = tf.get_variable(name='new_x', shape=(1, 64), dtype=tf.float32)  # Variale

graph_editor.graph_replace(y, {x: new_x})          # ERROR
graph_editor.graph_replace(y, {x: new_x.value()})  # OK

変数も共有せずに完全に複製する

tf.train.export_meta_graph()tf.train.import_meta_graph()を使えば可能です。
meta_graph.export_scoped_meta_graph()meta_graph.import_scoped_meta_graph()を使えば可能です。

from tensorflow.python.framework import meta_graph
meta_graph_def, _ = meta_graph.export_scoped_meta_graph(export_scope='f')
meta_graph.import_scoped_meta_graph(meta_graph_def, import_scope='new_f')

ただし、複製元のグラフを作成するときに乱数のシード値を指定していた場合は、複製先も同じシード値を使うことになります。ご注意ください。
tf.train.export_meta_graph()がSaverを作っているのが理由だったようです。

1
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
4