TensorFlow学習パラメータのsave, restoreでつまった

  • 45
    Like
  • 0
    Comment
More than 1 year has passed since last update.

TensorFlowのsave, restoreで少しつまったのでメモ.

基本的な使い方

Tensorflowの学習パラーメータのsave, restoreにはtf.train.Saverを用います.

保存

保存にはtf.train.Saversaveメソッドを使います.


# 幾つか変数作成
w1 = tf.Variable(..., name="v1")
w2 = tf.Variable(..., name="w2")
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")

...

saver = tf.train.Saver()

...

# 保存
saver.save(sess, "model.ckpt", global_step=100)

上記により,変数w1,w2,b1,b2model.ckpkに保存されます.
その際同時に,chackpointファイル(model作成履歴のようなものが入っている)と,metaデータが保存されます.global_stepの指定は任意ですが,学習を途中で中断しまた後から再開したいときなどに有用です.

tf.train.Saver()に何も引数を指定しない場合は全ての変数が保存されますが,下記のように保存する変数を指定することもできます.


# v1をmy_v1として保存
tf.train.Saver('my_v1': v1)

# v1, v2のみ保存
tf.train.Saver([v1, v2])

読み込み

読み込みにはtf.train.Saverrestoreメソッドを使います.

w1 = tf.Variable(..., name="v1")
w2 = tf.Variable(..., name="w2")
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")

with tf.Session() as sess:
  # 変数の読み込み
  saver.restore(sess, "model.ckpt")

上記によりw1, w2, v1, v2に保存された値が読み込まれます.

また,保存されたデータがあるかどうかは,

ckpt = tf.train.get_checkpoint_state('./'):

により,チェックできます(例えば'./'にない場合はNoneが帰ります).

そのため,


with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state('./'):
    if ckpt: # checkpointがある場合
        last_model = ckpt.model_checkpoint_path # 最後に保存したmodelへのパス
        print "load " + last_model
        saver.restore(sess, last_model) # 変数データの読み込み
        ...

    else: # 保存データがない場合
        init = tf.initialize_all_variables()
        sess.run(init) #変数を初期化して実行

のように,保存データの有無で,変数を読み込むか,新規に学習を始めるかを分岐させることができます.

注意どころ

保存の際のtf.Variableのnameと,読み込み時に宣言するtf.Variableのnameを揃える.

当然変数名ではなく,nameの指定が揃っていなければ,正しく読み込まれません.

保存した変数の数と,saver = tf.train.Saver()を呼ぶまでに宣言する変数数をそろえる.

保存,読み込みともに,tf.train.Saver()が呼ばれる前までに呼ばれた引数が対象になります.
そのため,何も保存時も何も変数が宣言されていない状態でtf.train.Saver()を呼び出すと,ValueError: No variables to saveと怒られます.

読み込んだ変数に関しては初期化は不要

saver.restore(sess, model)により読み込まれた変数は,Initializeをしなくても値が入っている状態です.
そのため,restoreを読んだ後に


init = tf.initialize_all_variables()
sess.run(init) #変数を初期化して実行

などとしてしまうと,読み込んだ値が初期化されてしまいます.

w1,w2は読み込んで,w3は新しく学習させたい,という際にはtf.initialize_variables()を用います.