LoginSignup
6
10

More than 5 years have passed since last update.

TensorflowのSupervisorの使い方

Last updated at Posted at 2017-10-14

TensorFlowには`tf.train.Supervisor`という便利なツールがある.

  • チェックポイントを自動で保存し,そこから学習を再開できるので,1週間かかる学習も安心
  • キュー (QueueRunner) を自動でスタートするので,大量のTFRecordファイルも読み込める
  • summaryも自動で保存されるので,TensorBoardでの可視化もバッチリ

ということで大規模学習に最適なのに,使い方のサンプルがほとんどみつからない.
そこでminimal sampleを作ってみた.

スクリプト

モデルは$y=ax+b$というシンプルな線形回帰.

データ生成

まずはtfrecordファイルを作るデータ生成スクリプト.

generator.py
import tensorflow as tf
import numpy as np
import os.path

# Ground-truth line y = a * x + b that the regression should recover.
a, b = 5, -3.2

# Inputs: train on [0, 5) and test on the disjoint range [5, 10), step 0.1.
x_train = np.arange(0, 5, 0.1)
x_test = np.arange(5, 10, 0.1)

# Targets: the exact line plus small Gaussian noise (scale 0.01).
# The generator is consumed left to right, so train noise is drawn before test noise.
y_train, y_test = (
    a * xs + b + np.random.normal(scale=0.01, size=len(xs))
    for xs in (x_train, x_test)
)



def save_data(x, y, s='train', out_dir='./'):
    '''
    Save each (x[i], y[i]) pair as its own single-example TFRecord file.

    Files are named ``{i:03d}_{s}.tfrecords`` inside *out_dir*. Both values
    are stored as raw float32 bytes under the bytes features 'x' and 'y'.

    Args:
        x: 1-D sequence of inputs.
        y: 1-D sequence of targets, same length as *x*.
        s: split tag used in the file name (e.g. 'train' or 'test').
        out_dir: directory the files are written to (default: current dir).

    Raises:
        ValueError: if *x* and *y* differ in length.
    '''
    if len(x) != len(y):
        raise ValueError(
            'x and y must have the same length, got {} and {}'.format(len(x), len(y)))

    def _bytes_feature(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    def _float32_bytes(value):
        # Raw float32 bytes; train.py decodes them with tf.decode_raw(..., tf.float32).
        # tobytes() is the non-deprecated equivalent of tostring().
        return np.array([value]).astype(np.float32).tobytes()

    for i, (xi, yi) in enumerate(zip(x, y)):
        tfrecords_filename = os.path.join(out_dir, '{:03d}_{}.tfrecords'.format(i, s))
        example = tf.train.Example(features=tf.train.Features(feature={
            'x': _bytes_feature(_float32_bytes(xi)),
            'y': _bytes_feature(_float32_bytes(yi)),
        }))
        # The context manager closes the file even if writing raises.
        with tf.python_io.TFRecordWriter(tfrecords_filename) as writer:
            writer.write(example.SerializeToString())


# Write the training split first, then the test split (one file per example).
for xs, ys, split in ((x_train, y_train, 'train'), (x_test, y_test, 'test')):
    save_data(xs, ys, split)

学習・テスト

次に学習スクリプト.`--test`オプションを付けるとテスト(推論のみ)にも使える.

train.py

import argparse
import sys
import time
from datetime import datetime

import tensorflow as tf
from tensorflow import gfile
from tensorflow import logging



def read_and_decode(filename_queue):
    '''
    Pull one serialized tf.train.Example off *filename_queue* and decode it.

    The 'x' and 'y' features are raw byte strings. They are reinterpreted
    as float32 and reshaped to a length-1 vector.

    Returns:
        (x, y): float32 tensors of shape [1].
    '''
    # The record key returned by the reader is not needed.
    _, serialized = tf.TFRecordReader().read(filename_queue)

    parsed = tf.parse_single_example(
        serialized,
        features={name: tf.FixedLenFeature([], tf.string) for name in ('x', 'y')},
    )

    def _to_vector(raw_bytes):
        return tf.reshape(tf.decode_raw(raw_bytes, tf.float32), [1])

    # x is built before y, matching the original op creation order.
    return _to_vector(parsed['x']), _to_vector(parsed['y'])


def model(x):
    '''
    Linear model y = a * x + b with a trainable slope and intercept.

    Returns:
        (prediction, a, b) where a is a 1x1 matrix variable and b a
        length-1 vector variable.
    '''
    # Keep the creation order (slope, then intercept). The auto-generated
    # variable names, which checkpoints are keyed on, depend on it.
    slope = tf.Variable(tf.truncated_normal([1, 1], stddev=0.1))
    intercept = tf.Variable(tf.constant(0.1, shape=[1]))
    return slope * x + intercept, slope, intercept


def load_batch():
    '''
    Build the input pipeline: glob TFRecord files, queue them, decode, batch.

    Reads the module-level FLAGS set in ``__main__`` (train_tfrecords,
    num_epochs, train, batch_size).

    Returns:
        (x_batch, y_batch): float32 tensors of shape [batch, 1].

    Raises:
        IOError: if the glob pattern matches no files.
    '''
    files = gfile.Glob(FLAGS.train_tfrecords)
    if not files:
        raise IOError("Can't find files: " + FLAGS.train_tfrecords)
    print(files)

    # Shuffle file order only for training. Evaluation reads the files in
    # the order gfile.Glob returned them. num_epochs bounds the queue, and
    # exhausting it ends the loop in main() via OutOfRangeError.
    filename_queue = tf.train.string_input_producer(files,
                                                    num_epochs=FLAGS.num_epochs,
                                                    shuffle=FLAGS.train)

    x, y = read_and_decode(filename_queue)

    # When evaluating, keep the final partial batch so every example is scored.
    # Otherwise up to batch_size - 1 trailing examples would be silently dropped.
    # Training keeps full, statically shaped batches as before.
    x_batch, y_batch = tf.train.batch([x, y], batch_size=FLAGS.batch_size,
                                      allow_smaller_final_batch=not FLAGS.train)

    return x_batch, y_batch


def main(_):
    '''
    Build the linear-regression graph and run it under a tf.train.Supervisor.

    With FLAGS.train, each step trains on one batch, prints the prediction
    for that same batch, and checkpoints to FLAGS.checkpoint_dir every
    second. It also periodically writes an 'rmse' summary. Without
    FLAGS.train, it only prints predictions and saves no checkpoints.
    The loop ends when the input queue raises OutOfRangeError after
    FLAGS.num_epochs epochs.
    '''
    log_every = 100      # steps between console progress lines (training only)
    summary_every = 10   # steps between TensorBoard summary writes (training only)

    with tf.Graph().as_default():

        # data
        x_batch, y_batch = load_batch()

        # model
        y_conv, a, b = model(x_batch)

        # loss
        rmse = tf.sqrt(tf.losses.mean_squared_error(y_batch, y_conv))
        tf.summary.scalar('rmse', rmse)
        summary_op = tf.summary.merge_all()

        # minimizer
        # Keep creation order (model vars, global_step, then Adam slots) so
        # existing checkpoints still restore by variable name.
        global_step = tf.Variable(0, trainable=False)
        train_step = tf.train.AdamOptimizer(0.01).minimize(rmse, global_step=global_step)

        # supervisor
        # Checkpoint every second while training; never save during test.
        save_model_secs = 1 if FLAGS.train else 0

        # summary_op=None disables the Supervisor's background summary thread.
        # That thread would evaluate `rmse`, which dequeues its own batch and
        # silently steals examples from training and from the unshuffled
        # test pass. Summaries are instead computed below from the batch the
        # loop already fetched.
        sv = tf.train.Supervisor(logdir=FLAGS.checkpoint_dir,
                                 global_step=global_step,
                                 summary_op=None,
                                 save_model_secs=save_model_secs,
                                 )

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True  # allocate GPU memory on demand

        with sv.managed_session(config=config) as sess:

            step = 0
            try:
                while not sv.should_stop():
                    # Fetch everything in ONE run. Every sess.run that touches
                    # x_batch/y_batch dequeues a fresh batch, so separate runs
                    # would train on one batch and report on another.
                    fetches = {'x': x_batch, 'y_true': y_batch, 'y_est': y_conv}
                    if FLAGS.train:
                        fetches['train'] = train_step
                        fetches['rmse'] = rmse
                        if step % summary_every == 0:
                            fetches['summary'] = summary_op
                    out = sess.run(fetches)

                    # y_batch is the ground truth and y_conv the model estimate.
                    print('x={:.3f}, y_est={:.3f}, y_true={:.3f}'.format(
                        out['x'][0][0], out['y_est'][0][0], out['y_true'][0][0]))

                    if 'summary' in out:
                        # The Supervisor looks up global_step itself (no dequeue).
                        sv.summary_computed(sess, out['summary'])

                    if FLAGS.train and step % log_every == 0:
                        # Variable reads only: this run does not touch the input queue.
                        g_step, a_val, b_val = sess.run([global_step, a, b])
                        print('global step {:05d} '.format(g_step), end=" ")
                        print('rmse {:010.2f} '.format(out['rmse']), end=" ")
                        print(a_val[0][0], end=" ")  # a is a 1x1 matrix
                        print(b_val[0])

                    step += 1

            except tf.errors.OutOfRangeError:
                print('Training ends with {} steps'.format(step))

            # No explicit sv.stop(): managed_session stops the Supervisor's
            # services when this with-block exits.

        print('end')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # --train and --test toggle the same boolean; training is the default.
    parser.add_argument('--train', dest='train', action='store_true',
                        help='train the model (default)')
    parser.add_argument('--test', dest='train', action='store_false',
                        help='run inference only: no training, no checkpoint saving')
    parser.set_defaults(train=True)

    # NOTE: the default pattern matches both *_train and *_test files written
    # by generator.py; pass an explicit pattern to select one split.
    parser.add_argument(
        '--train_tfrecords',
        type=str,
        default='./*.tfrecords',
        help='glob pattern of tfrecord files'
    )
    parser.add_argument(
        '--num_epochs',
        type=int,
        default=100,
        help='Number of epochs'
    )
    parser.add_argument(
        '--batch_size',
        type=int,
        default=1,
        help='batch size'
    )
    parser.add_argument(
        '--checkpoint_dir',
        type=str,
        default='./tmp',
        help='directory where checkpoint files are stored'
    )
    # FLAGS is a module-level global read by load_batch() and main().
    FLAGS, unparsed = parser.parse_known_args()
    # Unrecognised arguments are forwarded to tf.app.run.
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

使い方

まずはデータを生成.

`NNN_train.tfrecords`と`NNN_test.tfrecords`というファイルが50個ずつ(1ファイルに1サンプル)できる.

データ生成
$ python3 generator.py

では学習.

学習
$ python3 train.py --train_tfrecords=*_train.tfrecords

['./000_train.tfrecords', './001_train.tfrecords', './002_train.tfrecords', './003_train.tfrecords', './004_train.tfrecords', './005_train.tfrecords', './006_train.tfrecords',
...
'./041_train.tfrecords', './042_train.tfrecords', './043_train.tfrecords', './044_train.tfrecords', './045_train.tfrecords', './046_train.tfrecords', './047_train.tfrecords', './048_train.tfrecords', './049_train.tfrecords']
x=3.500, y_est=14.303, y_true=-0.477
global step 00001  rmse 0000013.74  -0.162092 0.09
x=3.700, y_est=15.290, y_true=-0.497
x=1.200, y_est=2.803, y_true=-0.088
x=1.300, y_est=3.311, y_true=-0.087
x=0.200, y_est=-2.200, y_true=0.080
x=3.900, y_est=16.289, y_true=-0.379
x=3.800, y_est=15.809, y_true=-0.334
x=3.000, y_est=11.810, y_true=-0.209
x=2.500, y_est=9.307, y_true=-0.126
x=0.700, y_est=0.296, y_true=0.070
x=4.900, y_est=21.311, y_true=-0.278
x=2.700, y_est=10.295, y_true=-0.057
x=4.300, y_est=18.293, y_true=-0.133
x=1.000, y_est=1.800, y_true=0.108
x=4.000, y_est=16.790, y_true=-0.025
x=2.200, y_est=7.809, y_true=0.093
x=1.700, y_est=5.301, y_true=0.137
x=0.900, y_est=1.300, y_true=0.180
x=2.300, y_est=8.290, y_true=0.173
...
global step 02401  rmse 0000000.02  4.99713 -3.18101
...
x=1.100, y_est=2.313, y_true=2.303
x=3.900, y_est=16.289, y_true=16.301
x=4.900, y_est=21.311, y_true=21.301
x=4.500, y_est=19.292, y_true=19.303
x=0.800, y_est=0.804, y_true=0.804
x=1.800, y_est=5.797, y_true=5.798
x=2.100, y_est=7.292, y_true=7.290
x=3.500, y_est=14.303, y_true=14.272
x=3.700, y_est=15.290, y_true=15.267
x=1.000, y_est=1.800, y_true=1.793
x=4.600, y_est=19.804, y_true=19.760
x=1.600, y_est=4.792, y_true=4.796
x=1.900, y_est=6.305, y_true=6.303
x=1.400, y_est=3.805, y_true=3.815
x=4.000, y_est=16.790, y_true=16.819
x=4.200, y_est=17.814, y_true=17.826
x=3.800, y_est=15.809, y_true=15.831
x=2.700, y_est=10.295, y_true=10.328
x=0.700, y_est=0.296, y_true=0.321
x=2.200, y_est=7.809, y_true=7.815
x=2.000, y_est=6.790, y_true=6.807
Training ends with 2475 steps
end

学習データはシャッフルされてバッチになる.

Supervisorがチェックポイントを保存しているので,もう一度学習を実行すると保存済みのチェックポイントから学習を再開する.

もう一回学習
$ python3 train.py --train_tfrecords=*_train.tfrecords

['./000_train.tfrecords', './001_train.tfrecords', './002_train.tfrecords', './003_train.tfrecords', './004_train.tfrecords', './005_train.tfrecords', './006_train.tfrecords', 
...
'./044_train.tfrecords', './045_train.tfrecords', './046_train.tfrecords', './047_train.tfrecords', './048_train.tfrecords', './049_train.tfrecords']
x=2.000, y_est=6.790, y_true=6.820
global step 02208  rmse 0000000.02  5.00433 -3.18855
x=0.300, y_est=-1.696, y_true=-1.690
x=3.900, y_est=16.289, y_true=16.302
x=0.600, y_est=-0.178, y_true=-0.201
x=3.300, y_est=13.299, y_true=13.274
x=0.200, y_est=-2.200, y_true=-2.203
x=3.600, y_est=14.800, y_true=14.767
x=0.400, y_est=-1.193, y_true=-1.204
x=0.800, y_est=0.804, y_true=0.798
...
global step 04608  rmse 0000000.02  5.00579 -3.20404
...
x=3.300, y_est=13.299, y_true=13.295
x=3.000, y_est=11.810, y_true=11.809
x=0.500, y_est=-0.684, y_true=-0.688
x=2.600, y_est=9.818, y_true=9.820
x=0.700, y_est=0.296, y_true=0.315
x=3.400, y_est=13.796, y_true=13.822
x=2.900, y_est=11.314, y_true=11.317
x=1.100, y_est=2.313, y_true=2.309
x=3.600, y_est=14.800, y_true=14.801
x=2.800, y_est=10.811, y_true=10.800
x=2.400, y_est=8.811, y_true=8.796
x=1.500, y_est=4.308, y_true=4.298
x=3.800, y_est=15.809, y_true=15.799
Training ends with 2471 steps
end

そしてテスト

テストのバッチはシャッフルしないので(string_input_producerのshuffleがFalse),入力xが先頭から順番に評価される.

そしてテスト
$ python3 train.py --train_tfrecords=*_test.tfrecords --test --num_epochs=1

['./000_test.tfrecords', './001_test.tfrecords', './002_test.tfrecords', './003_test.tfrecords', './004_test.tfrecords', './005_test.tfrecords', './006_test.tfrecords', './007_test.tfrecords', 
...
'./042_test.tfrecords', './043_test.tfrecords', './044_test.tfrecords', './045_test.tfrecords', './046_test.tfrecords', './047_test.tfrecords', './048_test.tfrecords', './049_test.tfrecords']
x=5.100, y_est=22.314, y_true=22.277
x=5.200, y_est=22.800, y_true=22.776
x=5.300, y_est=23.312, y_true=23.276
x=5.400, y_est=23.798, y_true=23.775
x=5.500, y_est=24.306, y_true=24.274
x=5.600, y_est=24.805, y_true=24.774
x=5.700, y_est=25.309, y_true=25.273
x=5.800, y_est=25.792, y_true=25.773
x=5.900, y_est=26.306, y_true=26.272
x=6.000, y_est=26.796, y_true=26.772
x=6.100, y_est=27.303, y_true=27.271
x=6.200, y_est=27.802, y_true=27.771
x=6.300, y_est=28.311, y_true=28.270
x=6.400, y_est=28.785, y_true=28.770
x=6.500, y_est=29.285, y_true=29.269
x=6.600, y_est=29.786, y_true=29.769
x=6.700, y_est=30.304, y_true=30.268
x=6.800, y_est=30.794, y_true=30.768
x=6.900, y_est=31.296, y_true=31.267
x=7.000, y_est=31.792, y_true=31.767
x=7.100, y_est=32.321, y_true=32.266
x=7.200, y_est=32.806, y_true=32.766
x=7.300, y_est=33.288, y_true=33.265
x=7.400, y_est=33.802, y_true=33.765
x=7.500, y_est=34.301, y_true=34.264
x=7.600, y_est=34.788, y_true=34.764
x=7.700, y_est=35.300, y_true=35.263
x=7.800, y_est=35.793, y_true=35.763
x=7.900, y_est=36.291, y_true=36.262
x=8.000, y_est=36.822, y_true=36.761
x=8.100, y_est=37.295, y_true=37.261
x=8.200, y_est=37.801, y_true=37.760
x=8.300, y_est=38.304, y_true=38.260
x=8.400, y_est=38.799, y_true=38.759
x=8.500, y_est=39.311, y_true=39.259
x=8.600, y_est=39.797, y_true=39.758
x=8.700, y_est=40.318, y_true=40.258
x=8.800, y_est=40.795, y_true=40.757
x=8.900, y_est=41.314, y_true=41.257
x=9.000, y_est=41.790, y_true=41.756
x=9.100, y_est=42.297, y_true=42.256
x=9.200, y_est=42.789, y_true=42.755
x=9.300, y_est=43.297, y_true=43.255
x=9.400, y_est=43.803, y_true=43.754
x=9.500, y_est=44.311, y_true=44.254
x=9.600, y_est=44.805, y_true=44.753
x=9.700, y_est=45.296, y_true=45.253
x=9.800, y_est=45.796, y_true=45.752
x=9.900, y_est=46.303, y_true=46.252
Training ends with 49 steps
end
6
10
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
10