Python
TensorFlow

tensorflow.session内の特定のvariableを保存する

More than 1 year has passed since last update.

TensorflowのVariableをファイルに保存するにはtensorflow.train.Saverを使う.Tutorialにあるような方法だとSession内のすべてのVariableが保存される.特定のVariableのみを保存・復元するにはtensorflow.train.Saverの初期化関数に対象としたいVariableの一覧を辞書型で与えれば良い.

これによって複数のファイルからVariableを個別に読み込むことができるようになる.

save.py
import tensorflow as tf

def get_particular_variables(name):
    return {v.name: v for v in tf.all_variables() if v.name.find(name) >= 0}

def define_variables(var0_value, var1_value, var2_value):
    var0 = tf.Variable([var0_value])
    with tf.variable_scope('foo'):
        var1 = tf.Variable([var1_value])
    with tf.variable_scope('bar'):
        var2 = tf.Variable([var2_value])

    return var0, var1, var2


sess = tf.InteractiveSession()

# defines variables
var0, var1, var2 = define_variables(0.0, 0.0, 0.0)

# saving only variables whose name includes foo
saver = tf.train.Saver(get_particular_variables('foo'))

# initializing all of variables
sess.run(tf.initialize_all_variables())

print var0.eval(), var1.eval(), var2.eval()

# saving into file
saver.save(sess, './bar_val')
restore.py
import tensorflow as tf

def get_particular_variables(name):
    return {v.name: v for v in tf.all_variables() if v.name.find(name) >= 0}

def define_variables(var0_value, var1_value, var2_value):
    var0 = tf.Variable([var0_value])
    with tf.variable_scope('foo'):
        var1 = tf.Variable([var1_value])
    with tf.variable_scope('bar'):
        var2 = tf.Variable([var2_value])

    return var0, var1, var2

sess = tf.InteractiveSession()

# defines variables
var0, var1, var2 = define_variables(1.0, 1.0, 1.0)

# restoring only variables whole name includes foo
saver = tf.train.Saver(get_particular_variables('foo'))

# initializing all of variables
sess.run(tf.initialize_all_variables())
print 'before restoring: ', var0.eval(), var1.eval(), var2.eval()

# restoring variable from file
saver.restore(sess, './bar_val')
print 'after restoring only var in foo: ', var0.eval(), var1.eval(), var2.eval()

だたしこの方法では長い名前や,名前空間の階層に注意する必要がある.例えば,

variable name-of-variable
var0 Variable:0
var1 foo/Variable:0
var2 foo/bar/Variable:0
var3 foobar/Variable:0

のような場合,上記のget_particular_variables('foo')を実行するとvar1, var2, var3の3つが返される.このように検索条件によっては余計なvariableが保存されていて復元時に思わぬバグを生む可能性がある.