global_step を保持しないまま学習を再開、追加学習すると、
全体として何stepしたのかわからないのと、
model.ckpt-xxx (xxxはstep数) がまた0からになってしまうのが嫌だったので。
import os
from datetime import datetime
import tensorflow as tf
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('checkpoint_dir', "./models/",
"""Directory path of model checkpoint save.""")
tf.app.flags.DEFINE_string('checkpoint_name', "model.ckpt",
"""Checkpoint name.""")
tf.app.flags.DEFINE_integer('max_steps', 1000,
"""Number of batches to run.""")
def main(argv=None):
global_step = tf.Variable(0, name='global_step')
global_step_holder = tf.placeholder(tf.int32)
global_step_op = global_step.assign(global_step_holder)
saver = tf.train.Saver()
with tf.Session() as sess:
checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
if checkpoint:
print(checkpoint.model_checkpoint_path)
print("variables were restored.")
saver = tf.train.Saver()
saver.restore(sess, checkpoint.model_checkpoint_path)
else:
print("variables were initialized.")
sess.run(tf.global_variables_initializer())
checkpoint_path = os.path.join(FLAGS.checkpoint_dir, FLAGS.checkpoint_name);
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
try:
for step in range(global_step.eval() + 1, FLAGS.max_steps + 1):
# do some train operation here
if step % 10 == 0:
sess.run(global_step_op, feed_dict={global_step_holder: step})
print("%s: step %d)" % (datetime.now(), step))
saver.save(sess, checkpoint_path, global_step=step)
finally:
coord.request_stop()
coord.join(threads)
if __name__ == '__main__':
tf.app.run()
global_step を 保存できるようにしたら、
sess.run(global_step_op, feed_dict={global_step_holder: step})
として代入するオペレーションを実行してあげることでsaveした時に保持されます。