機械学習
DeepLearning
TensorFlow

TensorFlow Saverで保存する世代数を指定する方法 (元旦の悲劇)

More than 1 year has passed since last update.

経緯

2016年から2017年にかけて、年末年始の長期休暇が元旦のみとなってしまったフリーエンジニアの私。普通ならそんな職場はブラックと呼ばれるのでしょうが、そこを笑顔でクリアしなければならないのがフリーランスの悲しいところ。

大晦日の夕方、TensorFlowを用いて開発したアプリケーション起動し、年越しDeepLearningを開始。学習の成果はTensorBoard用のログファイルとtf.train.SaverによるCheckpointファイルを使用して年明けに確認することにしました。

年が明けて2017年1月2日、人の気配が全くしないゴーストタウンになったかのような東京都心部のオフィスで仕事始めとなりましたが...

Checkpointファイルが無い!!
否、予定通りのディレクトリに保存されていますが予定よりも大幅に少ない!!

はい、その通りです。
ディフォルトではCheckpointファイルは5世代のみ保存されることを失念しておりました...

2017年は残念な仕事始めとなりました。

尚、悲劇(失態)が起こったのがたまたま元旦であっただけで、元旦特有の事象ではありません。

TensorFlowのSave機能で保存する世代数を指定する方法

TensorFlowで学習済みモデルを保存する場合は下記のような実装になります。

実装例1
saver = tf.train.Saver()

if step % 1000 == 0:
    saver.save(sess, 'my-model', global_step=step)

global_stepパラメータを指定することでCheckpointファイルに-xxxxのようなSuffixが付加され、上書きを防止するカラクリになっています。

Checkpointファイルの出力例
[u01@tf0101 save]# ls -l
合計 5884
-rw-r--r-- 1 u01 u01 1200272  1月  2 20:01 my-model-1000.ckpt.data-00000-of-00001
-rw-r--r-- 1 u01 u01     280  1月  2 20:01 my-model-1000.ckpt.index
-rw-r--r-- 1 u01 u01 4816308  1月  2 20:01 my-model-1000.ckpt.meta

しかし、この実装では最大5世代しか保存されないため、6世代目が出力されたタイミングで1世代目が削除されてしまいます。

そこで下記のように保存する世代数を指定します。

実装例2
saver = tf.train.Saver(max_to_keep=100)

if step % 1000 == 0:
    saver.save(sess, 'my-model', global_step=step)

tf.train.Saverにmax_to_keepパラメータを設定することで保存する世代数を指定可能です。

世代数を指定せず無限に保存する方法

下記のいずれかの方法で世代管理されず無限に保存されます。

実装例3
saver = tf.train.Saver(max_to_keep=0)
saver = tf.train.Saver(max_to_keep=None)

1世代だけ保存されれば良い場合

最も単純な実装はファイル名を固定してしまう方法ですね。

実装例5
saver = tf.train.Saver()

if step % 1000 == 0:
    saver.save(sess, 'my-model')

ファイル名にStep数やIteration数などを含まず、global_stepパラメータも指定しなければファイル名が固定されるためCheckpointファイルは常に上書きされるため、1世代分しか残りません。

またはmax_to_keep=1を指定しても同様の結果になります。

実装例6
saver = tf.train.Saver(max_to_keep=1)

if step % 1000 == 0:
    saver.save(sess, 'my-model', global_step=step)

この方式では最後のStep数がファイル名に含まれますが、最後のCheckpointファイル以外は削除されるため1世代しか残りません。

その他

max_to_keepパラメータ以外にkeep_checkpoint_every_n_hoursパラメータも用意されています。
実際に試したことはありませんが、指定の時間が経過すると古いCheckpointファイルが削除されるようです。