TensorFlow: Resuming Training While Preserving global_step

Posted at 2017-05-20

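If you resume or continue training without preserving global_step,
you can no longer tell how many steps have been run in total,
and the model.ckpt-xxx suffix (where xxx is the step count) starts again from 0. I didn't like that, so here is what I do.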
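
As a small illustration of the naming convention mentioned above: when `saver.save` is given a `global_step`, it appends `-<step>` to the checkpoint path. The helper below is a minimal sketch that reads that suffix back out of a path. The function name `step_from_checkpoint_path` is hypothetical and not part of TensorFlow.

```python
import re


def step_from_checkpoint_path(path):
    """Hypothetical helper: return the step suffix of a checkpoint path, or None.

    Assumes the path ends in "-<digits>", e.g. "./models/model.ckpt-120".
    """
    match = re.search(r"-(\d+)$", path)
    return int(match.group(1)) if match else None


print(step_from_checkpoint_path("./models/model.ckpt-120"))
print(step_from_checkpoint_path("./models/model.ckpt"))
```

If the step counter is not preserved across runs, this suffix restarts from small numbers on every run, which is exactly the problem this post works around.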

import os
from datetime import datetime
import tensorflow as tf

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('checkpoint_dir', "./models/",
                           """Directory path of model checkpoint save.""")
tf.app.flags.DEFINE_string('checkpoint_name', "model.ckpt",
                           """Checkpoint name.""")
tf.app.flags.DEFINE_integer('max_steps', 1000,
                            """Number of batches to run.""")
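

def main(argv=None):
    # Keep the step counter in a non-trainable variable so the Saver checkpoints it.
    global_step = tf.Variable(0, name='global_step', trainable=False)
    # Placeholder and assign op used to write the Python-side step into the variable.
    global_step_holder = tf.placeholder(tf.int32)
    global_step_op = global_step.assign(global_step_holder)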

    saver = tf.train.Saver()

    with tf.Session() as sess:
        checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        if checkpoint:
            # Restore all variables, including global_step, from the latest checkpoint.
            print(checkpoint.model_checkpoint_path)
            print("variables were restored.")
            saver.restore(sess, checkpoint.model_checkpoint_path)
        else:
            print("variables were initialized.")
            sess.run(tf.global_variables_initializer())

        checkpoint_path = os.path.join(FLAGS.checkpoint_dir, FLAGS.checkpoint_name)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for step in range(global_step.eval() + 1, FLAGS.max_steps + 1):

                # do some train operation here

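                if step % 10 == 0:
                    # Write the current step into global_step so it is saved with the checkpoint.
                    sess.run(global_step_op, feed_dict={global_step_holder: step})
                    print("%s: step %d" % (datetime.now(), step))
                    # The checkpoint is written as model.ckpt-<step>.
                    saver.save(sess, checkpoint_path, global_step=step)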
        finally:
            coord.request_stop()
            coord.join(threads)

if __name__ == '__main__':
    tf.app.run()

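Once global_step is a variable that the Saver can save, run the assign operation

sess.run(global_step_op, feed_dict={global_step_holder: step})

to write the current step into it. The value is then kept in the checkpoint when you save.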
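
The effect on the training loop can be sketched in plain Python, without TensorFlow. This is only an illustration of the arithmetic in `range(global_step.eval() + 1, FLAGS.max_steps + 1)` and the `step % 10 == 0` check. The function `resume_plan` and its parameter names are assumptions made for this example.

```python
def resume_plan(restored_step, max_steps, save_every=10):
    """Return (steps the loop will run, steps at which a checkpoint is written).

    restored_step stands in for the value of global_step after restore
    (0 when no checkpoint exists and the variables are freshly initialized).
    """
    steps = list(range(restored_step + 1, max_steps + 1))
    saved = [s for s in steps if s % save_every == 0]
    return steps, saved


# Resuming from model.ckpt-120 with max_steps raised to 150.
steps, saved = resume_plan(restored_step=120, max_steps=150)
print(steps[0], steps[-1])  # 121 150
print(saved)                # [130, 140, 150]
```

Because global_step is only written at save time, a run interrupted between two saves resumes from the last saved step, which matches the state of the other restored variables.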
