TensorflowのVariableをファイルに保存するにはtensorflow.train.Saverを使う.Tutorialにあるような方法だとSession内のすべてのVariableが保存される.特定のVariableのみを保存・復元するにはtensorflow.train.Saverの初期化関数に対象としたいVariableの一覧を辞書型で与えれば良い.
これによって複数のファイルからVariableを個別に読み込むことができるようになる.
save.py
import tensorflow as tf
def get_particular_variables(name):
return {v.name: v for v in tf.all_variables() if v.name.find(name) >= 0}
def define_variables(var0_value, var1_value, var2_value):
var0 = tf.Variable([var0_value])
with tf.variable_scope('foo'):
var1 = tf.Variable([var1_value])
with tf.variable_scope('bar'):
var2 = tf.Variable([var2_value])
return var0, var1, var2
sess = tf.InteractiveSession()
# defines variables
var0, var1, var2 = define_variables(0.0, 0.0, 0.0)
# saving only variables whose name includes foo
saver = tf.train.Saver(get_particular_variables('foo'))
# initializing all of variables
sess.run(tf.initialize_all_variables())
print var0.eval(), var1.eval(), var2.eval()
# saving into file
saver.save(sess, './bar_val')
restore.py
import tensorflow as tf
def get_particular_variables(name):
return {v.name: v for v in tf.all_variables() if v.name.find(name) >= 0}
def define_variables(var0_value, var1_value, var2_value):
var0 = tf.Variable([var0_value])
with tf.variable_scope('foo'):
var1 = tf.Variable([var1_value])
with tf.variable_scope('bar'):
var2 = tf.Variable([var2_value])
return var0, var1, var2
sess = tf.InteractiveSession()
# defines variables
var0, var1, var2 = define_variables(1.0, 1.0, 1.0)
# restoring only variables whole name includes foo
saver = tf.train.Saver(get_particular_variables('foo'))
# initializing all of variables
sess.run(tf.initialize_all_variables())
print 'before restoring: ', var0.eval(), var1.eval(), var2.eval()
# restoring variable from file
saver.restore(sess, './bar_val')
print 'after restoring only var in foo: ', var0.eval(), var1.eval(), var2.eval()
だたしこの方法では長い名前や,名前空間の階層に注意する必要がある.例えば,
variable | name-of-variable |
---|---|
var0 | Variable:0 |
var1 | foo/Variable:0 |
var2 | foo/bar/Variable:0 |
var3 | foobar/Variable:0 |
のような場合,上記のget_particular_variables('foo')を実行するとvar1, var2, var3の3つが返される.このように検索条件によっては余計なvariableが保存されていて復元時に思わぬバグを生む可能性がある.