はじめに
Tensorflowで転移学習させたモデルを保存し、再度それを読み込んで処理した時に諸々問題が発生した。
半日ほどハマったので、備忘録を兼ねてメモ。
問題の状況
1 tensorflow.train.saver.restore()を使って学習済みモデルを読み込んだ。この際に再学習させるモデルは若干変えたため、以下のように一部のパラメータだけを読み込んだ。
sess = tf.Session()
......
all_vars = tf.all_variables()
......
sess.run(tf.global_variables_initializer())
......
var_to_restore = []
for num, var1 in enumerate(all_vars):
_, deter, _ = var1.name.split('/', 2)
if deter != 'REMOVE_NAME':
var_to_restore.append(var1)
saver = tf.train.Saver(var_to_restore)
saver.restore(sess, RESTORED_MODEL_NAME)
2 これを学習させ、適当なタイミングでパラメータを保存した。
_ = saver.save(sess, FILE_NAME_OF_CKPT)
3 これを再び読み込んで、再々学習させたり、推論させる際にErrorが出たり、結果がおかしくなったりした。
問題の所在
saver = tf.train.Saver(var_to_restore)
としているので、var_to_restoreに該当するパラメータだけでsaverインスタンスが初期化されている。
つまり全てのパラメータのリストが入っていない。
ちなみにtensorflowの公式サイトの記述ではSaverクラスの初期化は以下のようになっている。
__init__(
var_list=None,
reshape=False,
sharded=False,
max_to_keep=5,
keep_checkpoint_every_n_hours=10000.0,
name=None,
restore_sequentially=False,
saver_def=None,
builder=None,
defer_build=False,
allow_empty=False,
write_version=tf.train.SaverDef.V2,
pad_step_number=False,
save_relative_paths=False,
filename=None
)
つまり第1引数のvar_listにvar_to_restoreだけ入っている。
この状態で
_ = saver.save(sess, FILE_NAME_OF_CKPT)
と保存すると、その一部のパラメータだけで保存されてしまう。
解決法
そこで以下のように直後にsaverインスタンスを初期化しなおして解決した。
......
saver = tf.train.Saver(var_to_restore)
saver.restore(sess, RESTORED_MODEL_NAME)
saver = tf.train.Saver()
ちなみにTensorflowの公式ドキュメントに
var_list: A list of Variable/SaveableObject, or a dictionary mapping names to SaveableObjects. If None, defaults to the list of all saveable objects.
とあるように、初期化で引数を指定しなければデフォルトでvar_listには全てのパラメータが入る。