DeepLearning
python3
TensorFlow

Tensorflowで転移学習させたモデルを保存する際の注意点

はじめに

Tensorflowで転移学習させたモデルを保存し、再度それを読み込んで処理した時に諸々問題が発生した。

半日ほどハマったので、備忘録を兼ねてメモ。

問題の状況

1 tensorflow.train.saver.restore()を使って学習済みモデルを読み込んだ。この際に再学習させるモデルは若干変えたため、以下のように一部のパラメータだけを読み込んだ。

sess = tf.Session()
......
all_vars = tf.all_variables()
......
sess.run(tf.global_variables_initializer())
......
var_to_restore = []
for num, var1 in enumerate(all_vars):
    _, deter, _ = var1.name.split('/', 2)
    if deter != 'REMOVE_NAME':
        var_to_restore.append(var1)
saver = tf.train.Saver(var_to_restore)
saver.restore(sess, RESTORED_MODEL_NAME)

2 これを学習させ、適当なタイミングでパラメータを保存した。

_ = saver.save(sess, FILE_NAME_OF_CKPT)

3 これを再び読み込んで、再々学習させたり、推論させる際にErrorが出たり、結果がおかしくなったりした。

問題の所在

saver = tf.train.Saver(var_to_restore)

としているので、var_to_restoreに該当するパラメータだけでsaverインスタンスが初期化されている。

つまり全てのパラメータのリストが入っていない。

ちなみにtensorflowの公式サイトの記述ではSaverクラスの初期化は以下のようになっている。

__init__(
    var_list=None,
    reshape=False,
    sharded=False,
    max_to_keep=5,
    keep_checkpoint_every_n_hours=10000.0,
    name=None,
    restore_sequentially=False,
    saver_def=None,
    builder=None,
    defer_build=False,
    allow_empty=False,
    write_version=tf.train.SaverDef.V2,
    pad_step_number=False,
    save_relative_paths=False,
    filename=None
)

つまり第1引数のvar_listにvar_to_restoreだけ入っている。

この状態で

_ = saver.save(sess, FILE_NAME_OF_CKPT)

と保存すると、その一部のパラメータだけで保存されてしまう。

解決法

そこで以下のように直後にsaverインスタンスを初期化しなおして解決した。

......
saver = tf.train.Saver(var_to_restore)
saver.restore(sess, RESTORED_MODEL_NAME)
saver = tf.train.Saver()

ちなみにTensorflowの公式ドキュメントに

var_list: A list of Variable/SaveableObject, or a dictionary mapping names to SaveableObjects. If None, defaults to the list of all saveable objects.

とあるように、初期化で引数を指定しなければデフォルトでvar_listには全てのパラメータが入る。