TensorFlowのsave, restoreで少しつまったのでメモ.
基本的な使い方
Tensorflowの学習パラーメータのsave, restoreにはtf.train.Saverを用います.
保存
保存にはtf.train.Saverのsaveメソッドを使います.
# 幾つか変数作成
w1 = tf.Variable(..., name="v1")
w2 = tf.Variable(..., name="w2")
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
saver = tf.train.Saver()
...
# 保存
saver.save(sess, "model.ckpt", global_step=100)
上記により,変数w1,w2,b1,b2
がmodel.ckpk
に保存されます.
その際同時に,chackpointファイル(model作成履歴のようなものが入っている)と,metaデータが保存されます.global_stepの指定は任意ですが,学習を途中で中断しまた後から再開したいときなどに有用です.
tf.train.Saver()
に何も引数を指定しない場合は全ての変数が保存されますが,下記のように保存する変数を指定することもできます.
# v1をmy_v1として保存
tf.train.Saver('my_v1': v1)
# v1, v2のみ保存
tf.train.Saver([v1, v2])
読み込み
読み込みにはtf.train.Saverのrestoreメソッドを使います.
w1 = tf.Variable(..., name="v1")
w2 = tf.Variable(..., name="w2")
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
with tf.Session() as sess:
# 変数の読み込み
saver.restore(sess, "model.ckpt")
上記によりw1, w2, v1, v2に保存された値が読み込まれます.
また,保存されたデータがあるかどうかは,
ckpt = tf.train.get_checkpoint_state('./'):
により,チェックできます(例えば'./'にない場合はNoneが帰ります).
そのため,
with tf.Session() as sess:
ckpt = tf.train.get_checkpoint_state('./'):
if ckpt: # checkpointがある場合
last_model = ckpt.model_checkpoint_path # 最後に保存したmodelへのパス
print "load " + last_model
saver.restore(sess, last_model) # 変数データの読み込み
...
else: # 保存データがない場合
init = tf.initialize_all_variables()
sess.run(init) #変数を初期化して実行
のように,保存データの有無で,変数を読み込むか,新規に学習を始めるかを分岐させることができます.
注意どころ
保存の際のtf.Variable
のnameと,読み込み時に宣言するtf.Variable
のnameを揃える.
当然変数名ではなく,nameの指定が揃っていなければ,正しく読み込まれません.
保存した変数の数と,saver = tf.train.Saver()
を呼ぶまでに宣言する変数数をそろえる.
保存,読み込みともに,tf.train.Saver()
が呼ばれる前までに呼ばれた引数が対象になります.
そのため,何も保存時も何も変数が宣言されていない状態でtf.train.Saver()
を呼び出すと,ValueError: No variables to save
と怒られます.
読み込んだ変数に関しては初期化は不要
saver.restore(sess, model)
により読み込まれた変数は,Initializeをしなくても値が入っている状態です.
そのため,restoreを読んだ後に
init = tf.initialize_all_variables()
sess.run(init) #変数を初期化して実行
などとしてしまうと,読み込んだ値が初期化されてしまいます.
w1,w2は読み込んで,w3は新しく学習させたい,という際にはtf.initialize_variables()
を用います.