LoginSignup
100
96

More than 5 years have passed since last update.

TensorFlow学習パラメータのsave, restoreでつまった

Last updated at Posted at 2016-05-27

TensorFlowのsave, restoreで少しつまったのでメモ.

基本的な使い方

Tensorflowの学習パラーメータのsave, restoreにはtf.train.Saverを用います.

保存

保存にはtf.train.Saversaveメソッドを使います.


# 幾つか変数作成
w1 = tf.Variable(..., name="v1")
w2 = tf.Variable(..., name="w2")
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")

...

saver = tf.train.Saver()

...

# 保存
saver.save(sess, "model.ckpt", global_step=100)

上記により,変数w1,w2,b1,b2model.ckpkに保存されます.
その際同時に,chackpointファイル(model作成履歴のようなものが入っている)と,metaデータが保存されます.global_stepの指定は任意ですが,学習を途中で中断しまた後から再開したいときなどに有用です.

tf.train.Saver()に何も引数を指定しない場合は全ての変数が保存されますが,下記のように保存する変数を指定することもできます.


# v1をmy_v1として保存
tf.train.Saver('my_v1': v1)

# v1, v2のみ保存
tf.train.Saver([v1, v2])

読み込み

読み込みにはtf.train.Saverrestoreメソッドを使います.

w1 = tf.Variable(..., name="v1")
w2 = tf.Variable(..., name="w2")
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")

with tf.Session() as sess:
  # 変数の読み込み
  saver.restore(sess, "model.ckpt")

上記によりw1, w2, v1, v2に保存された値が読み込まれます.

また,保存されたデータがあるかどうかは,

ckpt = tf.train.get_checkpoint_state('./'):

により,チェックできます(例えば'./'にない場合はNoneが帰ります).

そのため,


with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state('./'):
    if ckpt: # checkpointがある場合
        last_model = ckpt.model_checkpoint_path # 最後に保存したmodelへのパス
        print "load " + last_model
        saver.restore(sess, last_model) # 変数データの読み込み
        ...

    else: # 保存データがない場合
        init = tf.initialize_all_variables()
        sess.run(init) #変数を初期化して実行

のように,保存データの有無で,変数を読み込むか,新規に学習を始めるかを分岐させることができます.

注意どころ

保存の際のtf.Variableのnameと,読み込み時に宣言するtf.Variableのnameを揃える.

当然変数名ではなく,nameの指定が揃っていなければ,正しく読み込まれません.

保存した変数の数と,saver = tf.train.Saver()を呼ぶまでに宣言する変数数をそろえる.

保存,読み込みともに,tf.train.Saver()が呼ばれる前までに呼ばれた引数が対象になります.
そのため,何も保存時も何も変数が宣言されていない状態でtf.train.Saver()を呼び出すと,ValueError: No variables to saveと怒られます.

読み込んだ変数に関しては初期化は不要

saver.restore(sess, model)により読み込まれた変数は,Initializeをしなくても値が入っている状態です.
そのため,restoreを読んだ後に


init = tf.initialize_all_variables()
sess.run(init) #変数を初期化して実行

などとしてしまうと,読み込んだ値が初期化されてしまいます.

w1,w2は読み込んで,w3は新しく学習させたい,という際にはtf.initialize_variables()を用います.

100
96
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
100
96