LoginSignup
4
4

More than 5 years have passed since last update.

tensorflow.session内の特定のvariableを保存する

Last updated at Posted at 2016-08-18

TensorflowのVariableをファイルに保存するにはtensorflow.train.Saverを使う.Tutorialにあるような方法だとSession内のすべてのVariableが保存される.特定のVariableのみを保存・復元するにはtensorflow.train.Saverの初期化関数に対象としたいVariableの一覧を辞書型で与えれば良い.

これによって複数のファイルからVariableを個別に読み込むことができるようになる.

save.py
import tensorflow as tf

def get_particular_variables(name):
    return {v.name: v for v in tf.all_variables() if v.name.find(name) >= 0}

def define_variables(var0_value, var1_value, var2_value):
    var0 = tf.Variable([var0_value])
    with tf.variable_scope('foo'):
        var1 = tf.Variable([var1_value])
    with tf.variable_scope('bar'):
        var2 = tf.Variable([var2_value])

    return var0, var1, var2


sess = tf.InteractiveSession()

# defines variables
var0, var1, var2 = define_variables(0.0, 0.0, 0.0)

# saving only variables whose name includes foo
saver = tf.train.Saver(get_particular_variables('foo'))

# initializing all of variables
sess.run(tf.initialize_all_variables())

print var0.eval(), var1.eval(), var2.eval()

# saving into file
saver.save(sess, './bar_val')
restore.py
import tensorflow as tf

def get_particular_variables(name):
    return {v.name: v for v in tf.all_variables() if v.name.find(name) >= 0}

def define_variables(var0_value, var1_value, var2_value):
    var0 = tf.Variable([var0_value])
    with tf.variable_scope('foo'):
        var1 = tf.Variable([var1_value])
    with tf.variable_scope('bar'):
        var2 = tf.Variable([var2_value])

    return var0, var1, var2

sess = tf.InteractiveSession()

# defines variables
var0, var1, var2 = define_variables(1.0, 1.0, 1.0)

# restoring only variables whole name includes foo
saver = tf.train.Saver(get_particular_variables('foo'))

# initializing all of variables
sess.run(tf.initialize_all_variables())
print 'before restoring: ', var0.eval(), var1.eval(), var2.eval()

# restoring variable from file
saver.restore(sess, './bar_val')
print 'after restoring only var in foo: ', var0.eval(), var1.eval(), var2.eval()

だたしこの方法では長い名前や,名前空間の階層に注意する必要がある.例えば,

variable name-of-variable
var0 Variable:0
var1 foo/Variable:0
var2 foo/bar/Variable:0
var3 foobar/Variable:0

のような場合,上記のget_particular_variables('foo')を実行するとvar1, var2, var3の3つが返される.このように検索条件によっては余計なvariableが保存されていて復元時に思わぬバグを生む可能性がある.

4
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
4