はじめに
Tensorflowで転移学習させたモデルを保存し、再度それを読み込んで処理した時に諸々問題が発生した。
半日ほどハマったので、備忘録を兼ねてメモ。
問題の状況
1 tensorflow.train.saver.restore()を使って学習済みモデルを読み込んだ。この際に再学習させるモデルは若干変えたため、以下のように一部のパラメータだけを読み込んだ。
sess = tf.Session()
......
all_vars = tf.all_variables()
......
sess.run(tf.global_variables_initializer())
......
var_to_restore = []
for num, var1 in enumerate(all_vars):
    _, deter, _ = var1.name.split('/', 2)
    if deter != 'REMOVE_NAME':
        var_to_restore.append(var1)
saver = tf.train.Saver(var_to_restore)
saver.restore(sess, RESTORED_MODEL_NAME)
2 これを学習させ、適当なタイミングでパラメータを保存した。
_ = saver.save(sess, FILE_NAME_OF_CKPT)
3 これを再び読み込んで、再々学習させたり、推論させる際にErrorが出たり、結果がおかしくなったりした。
問題の所在
saver = tf.train.Saver(var_to_restore)
としているので、var_to_restoreに該当するパラメータだけでsaverインスタンスが初期化されている。
つまり全てのパラメータのリストが入っていない。
ちなみにtensorflowの公式サイトの記述ではSaverクラスの初期化は以下のようになっている。
__init__(
    var_list=None,
    reshape=False,
    sharded=False,
    max_to_keep=5,
    keep_checkpoint_every_n_hours=10000.0,
    name=None,
    restore_sequentially=False,
    saver_def=None,
    builder=None,
    defer_build=False,
    allow_empty=False,
    write_version=tf.train.SaverDef.V2,
    pad_step_number=False,
    save_relative_paths=False,
    filename=None
)
つまり第1引数のvar_listにvar_to_restoreだけ入っている。
この状態で
_ = saver.save(sess, FILE_NAME_OF_CKPT)
と保存すると、その一部のパラメータだけで保存されてしまう。
解決法
そこで以下のように直後にsaverインスタンスを初期化しなおして解決した。
......
saver = tf.train.Saver(var_to_restore)
saver.restore(sess, RESTORED_MODEL_NAME)
saver = tf.train.Saver()
ちなみにTensorflowの公式ドキュメントに
var_list: A list of Variable/SaveableObject, or a dictionary mapping names to SaveableObjects. If None, defaults to the list of all saveable objects.
とあるように、初期化で引数を指定しなければデフォルトでvar_listには全てのパラメータが入る。