TensorFlowには`tf.train.Supervisor`という便利なツールがある.
- チェックポイントを自動で保存,そこから再開するから,1週間かかる学習も安心
- queueを自動でスタート,大量のtfrecordファイルも読み込める
- summaryも自動保存でTensorboardもバッチリ
ということで大規模学習に最適なのに,使い方のサンプルがほとんどみつからない.
そこでminimal sampleを作ってみた.
スクリプト
モデルは$y=ax+b$というシンプルな線形回帰.
データ生成
まずはtfrecordファイルを作るデータ生成スクリプト.
generator.py
import tensorflow as tf
import numpy as np
import os.path
# Ground-truth parameters of the line y = a * x + b.
a, b = 5, -3.2


def _noisy_line(xs):
    """Return ``a * xs + b`` plus i.i.d. Gaussian noise with std 0.01."""
    return a * xs + b + np.random.normal(scale=0.01, size=len(xs))


# Training inputs cover [0, 5) and test inputs [5, 10), both with step 0.1.
x_train, x_test = np.arange(0, 5, 0.1), np.arange(5, 10, 0.1)
# Noise for the training split is drawn first, then for the test split.
y_train = _noisy_line(x_train)
y_test = _noisy_line(x_test)
def save_data(x, y, s='train'):
    '''
    Save each (x[i], y[i]) pair to its own tfrecord file.

    One file named '{i:03d}_{s}.tfrecords' is written to the current
    directory per sample. Each file holds a single tf.train.Example whose
    'x' and 'y' features are the raw bytes of a length-1 float32 array.

    Args:
        x: sequence of inputs.
        y: sequence of outputs; must have the same length as x.
        s: split name used as the file-name suffix (e.g. 'train', 'test').

    Raises:
        ValueError: if x and y have different lengths.
    '''
    if len(x) != len(y):
        raise ValueError('x and y must have the same length: {} != {}'.format(len(x), len(y)))

    def _bytes_feature(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    def _float32_bytes(value):
        # The reader decodes these bytes with decode_raw(..., tf.float32).
        return np.array([value], dtype=np.float32).tobytes()

    for i, (x_i, y_i) in enumerate(zip(x, y)):
        tfrecords_filename = os.path.join('./', '{:03d}_{}.tfrecords'.format(i, s))
        example = tf.train.Example(features=tf.train.Features(feature={
            'x': _bytes_feature(_float32_bytes(x_i)),
            'y': _bytes_feature(_float32_bytes(y_i)),
        }))
        # `with` closes (and flushes) every writer, even if writing fails.
        with tf.python_io.TFRecordWriter(tfrecords_filename) as writer:
            writer.write(example.SerializeToString())
# Write one tfrecord file per sample, training split first, then test split.
for split_name, (inputs, outputs) in (('train', (x_train, y_train)),
                                      ('test', (x_test, y_test))):
    save_data(inputs, outputs, split_name)
学習・テスト
そして学習.オプションでテストにもなる.
train.py
import argparse
import sys
import time
from datetime import datetime
import tensorflow as tf
from tensorflow import gfile
from tensorflow import logging
def read_and_decode(filename_queue):
    """Read one serialized example from `filename_queue` and decode it.

    Both the 'x' and 'y' features are stored as raw float32 bytes.

    Returns:
        A pair (x, y) of float32 tensors, each reshaped to shape [1].
    """
    _, serialized_example = tf.TFRecordReader().read(filename_queue)
    parsed = tf.parse_single_example(
        serialized_example,
        features={name: tf.FixedLenFeature([], tf.string) for name in ('x', 'y')})

    def _to_vector(name):
        # Decode the raw bytes of feature `name` into a length-1 float vector.
        return tf.reshape(tf.decode_raw(parsed[name], tf.float32), [1])

    return _to_vector('x'), _to_vector('y')
def model(x):
    """Linear model y = a * x + b with trainable slope `a` and intercept `b`.

    `a` is a 1x1 variable initialised from a truncated normal (stddev 0.1).
    `b` is a length-1 variable initialised to 0.1.

    Returns:
        (y, a, b): the prediction tensor and both variables.
    """
    slope = tf.Variable(tf.truncated_normal([1, 1], stddev=0.1))
    intercept = tf.Variable(tf.constant(0.1, shape=[1]))
    return slope * x + intercept, slope, intercept
def load_batch():
    """Build the queue-based input pipeline and return (x_batch, y_batch).

    Files matching FLAGS.train_tfrecords are enqueued for FLAGS.num_epochs
    epochs. The file order is shuffled only when FLAGS.train is set.
    Decoded records are grouped into batches of FLAGS.batch_size.

    Raises:
        IOError: if the pattern matches no file.
    """
    pattern = FLAGS.train_tfrecords
    matched = gfile.Glob(pattern)
    if not matched:
        raise IOError("Can't find files: " + pattern)
    print(matched)
    queue = tf.train.string_input_producer(
        matched,
        num_epochs=FLAGS.num_epochs,
        shuffle=FLAGS.train)  # shuffle only for training
    x_single, y_single = read_and_decode(queue)
    x_batch, y_batch = tf.train.batch([x_single, y_single], batch_size=FLAGS.batch_size)
    return x_batch, y_batch
def main(_):
    """Build the graph and run training (--train) or evaluation (--test).

    The input pipeline from load_batch() feeds the linear model from
    model(). The loss is the RMSE between the labels (y_batch) and the
    prediction (y_conv). The Supervisor restores from and saves checkpoints
    in FLAGS.checkpoint_dir; checkpoints are only saved in training mode.
    The loop ends when the input queue is exhausted after FLAGS.num_epochs
    epochs and raises tf.errors.OutOfRangeError.
    """
    with tf.Graph().as_default():
        # data
        x_batch, y_batch = load_batch()
        # model
        y_conv, a, b = model(x_batch)
        # loss (labels first, predictions second)
        rmse = tf.sqrt(tf.losses.mean_squared_error(y_batch, y_conv))
        tf.summary.scalar('rmse', rmse)
        summary_op = tf.summary.merge_all()
        # minimizer
        global_step = tf.Variable(0, trainable=False)
        train_step = tf.train.AdamOptimizer(0.01).minimize(rmse, global_step=global_step)
        # supervisor
        if FLAGS.train:
            save_model_secs = 1  # save checkpoint every one second for training
        else:
            save_model_secs = 0  # do not save checkpoint
        # summary_op=None: the Supervisor's background summary thread would
        # evaluate `rmse` on its own and thereby dequeue (and drop) input
        # batches. Summaries are therefore computed in the loop below.
        sv = tf.train.Supervisor(logdir=FLAGS.checkpoint_dir,
                                 global_step=global_step,
                                 summary_op=None,
                                 save_summaries_secs=1,  # now only the global_step/sec counter
                                 save_model_secs=save_model_secs,
                                 )
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True  # always useful
        with sv.managed_session(config=config) as sess:
            step = 0
            try:
                while not sv.should_stop():
                    # Fetch everything that depends on the batch in ONE run.
                    # Each run touching x_batch/y_batch dequeues a new batch,
                    # so separate runs would silently skip training data.
                    fetches = [x_batch, y_batch, y_conv, rmse, summary_op]
                    if FLAGS.train:
                        fetches.append(train_step)
                    x_input, y_true, y_est, rmse_value, summary_str = sess.run(fetches)[:5]
                    sv.summary_computed(sess, summary_str)
                    print('x={:.3f}, y_est={:.3f}, y_true={:.3f}'.format(x_input[0][0], y_est[0][0], y_true[0][0]))
                    if step % 100 == 0 and FLAGS.train:
                        # Reading variables does not touch the input queue.
                        g_step, a_val, b_val = sess.run([global_step, a, b])
                        print('global step {:05d} '.format(g_step), end=" ")
                        print('rmse {:010.2f} '.format(rmse_value), end=" ")
                        print(a_val[0][0], end=" ")  # a is a 1x1 matrix
                        print(b_val[0])
                    step += 1
            except tf.errors.OutOfRangeError:
                print('Training ends with {} steps'.format(step))
            # No explicit stop: managed_session stops the supervisor on exit.
        print('end')
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # --train (the default) and --test both set the boolean FLAGS.train.
    parser.add_argument('--train', dest='train', action='store_true',
                        help='train the model (default)')
    parser.add_argument('--test', dest='train', action='store_false',
                        help='only evaluate the model restored from --checkpoint_dir')
    parser.set_defaults(train=True)
    parser.add_argument(
        '--train_tfrecords',
        type=str,
        default='./*.tfrecords',
        help='glob pattern of tfrecord files '
             '(note: the default matches both *_train and *_test files)'
    )
    parser.add_argument(
        '--num_epochs',
        type=int,
        default=100,
        help='Number of epochs'
    )
    parser.add_argument(
        '--batch_size',
        type=int,
        default=1,
        help='batch size'
    )
    parser.add_argument(
        '--checkpoint_dir',
        type=str,
        default='./tmp',
        help='directory where checkpoint files are stored'
    )
    # FLAGS is a module-level global read by main() and load_batch().
    FLAGS, unparsed = parser.parse_known_args()
    # Arguments argparse did not recognise are forwarded to tf.app.run.
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
使い方
まずはデータを生成.
*.tfrecordsというファイルがたくさんできる.
データ生成
$ python3 generator.py
では学習.
学習
$ python3 train.py --train_tfrecords=*_train.tfrecords
['./000_train.tfrecords', './001_train.tfrecords', './002_train.tfrecords', './003_train.tfrecords', './004_train.tfrecords', './005_train.tfrecords', './006_train.tfrecords',
...
'./041_train.tfrecords', './042_train.tfrecords', './043_train.tfrecords', './044_train.tfrecords', './045_train.tfrecords', './046_train.tfrecords', './047_train.tfrecords', './048_train.tfrecords', './049_train.tfrecords']
x=3.500, y_est=14.303, y_true=-0.477
global step 00001 rmse 0000013.74 -0.162092 0.09
x=3.700, y_est=15.290, y_true=-0.497
x=1.200, y_est=2.803, y_true=-0.088
x=1.300, y_est=3.311, y_true=-0.087
x=0.200, y_est=-2.200, y_true=0.080
x=3.900, y_est=16.289, y_true=-0.379
x=3.800, y_est=15.809, y_true=-0.334
x=3.000, y_est=11.810, y_true=-0.209
x=2.500, y_est=9.307, y_true=-0.126
x=0.700, y_est=0.296, y_true=0.070
x=4.900, y_est=21.311, y_true=-0.278
x=2.700, y_est=10.295, y_true=-0.057
x=4.300, y_est=18.293, y_true=-0.133
x=1.000, y_est=1.800, y_true=0.108
x=4.000, y_est=16.790, y_true=-0.025
x=2.200, y_est=7.809, y_true=0.093
x=1.700, y_est=5.301, y_true=0.137
x=0.900, y_est=1.300, y_true=0.180
x=2.300, y_est=8.290, y_true=0.173
...
global step 02401 rmse 0000000.02 4.99713 -3.18101
...
x=1.100, y_est=2.313, y_true=2.303
x=3.900, y_est=16.289, y_true=16.301
x=4.900, y_est=21.311, y_true=21.301
x=4.500, y_est=19.292, y_true=19.303
x=0.800, y_est=0.804, y_true=0.804
x=1.800, y_est=5.797, y_true=5.798
x=2.100, y_est=7.292, y_true=7.290
x=3.500, y_est=14.303, y_true=14.272
x=3.700, y_est=15.290, y_true=15.267
x=1.000, y_est=1.800, y_true=1.793
x=4.600, y_est=19.804, y_true=19.760
x=1.600, y_est=4.792, y_true=4.796
x=1.900, y_est=6.305, y_true=6.303
x=1.400, y_est=3.805, y_true=3.815
x=4.000, y_est=16.790, y_true=16.819
x=4.200, y_est=17.814, y_true=17.826
x=3.800, y_est=15.809, y_true=15.831
x=2.700, y_est=10.295, y_true=10.328
x=0.700, y_est=0.296, y_true=0.321
x=2.200, y_est=7.809, y_true=7.815
x=2.000, y_est=6.790, y_true=6.807
Training ends with 2475 steps
end
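学習データはシャッフルされてバッチになる.
なお,この記事に載せたログでは`y_est`と`y_true`の表示が逆になっている(`y_est`側が正解値,`y_true`側がモデルの推定値).
また,50ファイル×100エポック=5000件に対してステップ数が約2500になっているのは,このログを取ったときのコードではループ1回につき`sess.run`を2回(`train_step`用と表示用)呼んでおり,呼ぶたびにバッチが1つずつ消費されるためである.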
supervisorがチェックポイントを保存しているので,もう一回学習すると保存済みチェックポイントから学習を再開する.
もう一回学習
$ python3 train.py --train_tfrecords=*_train.tfrecords
['./000_train.tfrecords', './001_train.tfrecords', './002_train.tfrecords', './003_train.tfrecords', './004_train.tfrecords', './005_train.tfrecords', './006_train.tfrecords',
...
'./044_train.tfrecords', './045_train.tfrecords', './046_train.tfrecords', './047_train.tfrecords', './048_train.tfrecords', './049_train.tfrecords']
x=2.000, y_est=6.790, y_true=6.820
global step 02208 rmse 0000000.02 5.00433 -3.18855
x=0.300, y_est=-1.696, y_true=-1.690
x=3.900, y_est=16.289, y_true=16.302
x=0.600, y_est=-0.178, y_true=-0.201
x=3.300, y_est=13.299, y_true=13.274
x=0.200, y_est=-2.200, y_true=-2.203
x=3.600, y_est=14.800, y_true=14.767
x=0.400, y_est=-1.193, y_true=-1.204
x=0.800, y_est=0.804, y_true=0.798
...
global step 04608 rmse 0000000.02 5.00579 -3.20404
...
x=3.300, y_est=13.299, y_true=13.295
x=3.000, y_est=11.810, y_true=11.809
x=0.500, y_est=-0.684, y_true=-0.688
x=2.600, y_est=9.818, y_true=9.820
x=0.700, y_est=0.296, y_true=0.315
x=3.400, y_est=13.796, y_true=13.822
x=2.900, y_est=11.314, y_true=11.317
x=1.100, y_est=2.313, y_true=2.309
x=3.600, y_est=14.800, y_true=14.801
x=2.800, y_est=10.811, y_true=10.800
x=2.400, y_est=8.811, y_true=8.796
x=1.500, y_est=4.308, y_true=4.298
x=3.800, y_est=15.809, y_true=15.799
Training ends with 2471 steps
end
そしてテスト
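テストのバッチはシャッフルしないので(string_input_producerのshuffleがFalse),入力xが先頭から順番に評価される.
ただし下のログでは先頭のx=5.000が出力されておらず,49ステップで終了している.これは,Supervisorのサマリースレッドが`rmse`のサマリーを計算する際にキューからデータを1件読み出してしまうためである.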
そしてテスト
$ python3 train.py --train_tfrecords=*_test.tfrecords --test --num_epochs=1
['./000_test.tfrecords', './001_test.tfrecords', './002_test.tfrecords', './003_test.tfrecords', './004_test.tfrecords', './005_test.tfrecords', './006_test.tfrecords', './007_test.tfrecords',
...
'./042_test.tfrecords', './043_test.tfrecords', './044_test.tfrecords', './045_test.tfrecords', './046_test.tfrecords', './047_test.tfrecords', './048_test.tfrecords', './049_test.tfrecords']
x=5.100, y_est=22.314, y_true=22.277
x=5.200, y_est=22.800, y_true=22.776
x=5.300, y_est=23.312, y_true=23.276
x=5.400, y_est=23.798, y_true=23.775
x=5.500, y_est=24.306, y_true=24.274
x=5.600, y_est=24.805, y_true=24.774
x=5.700, y_est=25.309, y_true=25.273
x=5.800, y_est=25.792, y_true=25.773
x=5.900, y_est=26.306, y_true=26.272
x=6.000, y_est=26.796, y_true=26.772
x=6.100, y_est=27.303, y_true=27.271
x=6.200, y_est=27.802, y_true=27.771
x=6.300, y_est=28.311, y_true=28.270
x=6.400, y_est=28.785, y_true=28.770
x=6.500, y_est=29.285, y_true=29.269
x=6.600, y_est=29.786, y_true=29.769
x=6.700, y_est=30.304, y_true=30.268
x=6.800, y_est=30.794, y_true=30.768
x=6.900, y_est=31.296, y_true=31.267
x=7.000, y_est=31.792, y_true=31.767
x=7.100, y_est=32.321, y_true=32.266
x=7.200, y_est=32.806, y_true=32.766
x=7.300, y_est=33.288, y_true=33.265
x=7.400, y_est=33.802, y_true=33.765
x=7.500, y_est=34.301, y_true=34.264
x=7.600, y_est=34.788, y_true=34.764
x=7.700, y_est=35.300, y_true=35.263
x=7.800, y_est=35.793, y_true=35.763
x=7.900, y_est=36.291, y_true=36.262
x=8.000, y_est=36.822, y_true=36.761
x=8.100, y_est=37.295, y_true=37.261
x=8.200, y_est=37.801, y_true=37.760
x=8.300, y_est=38.304, y_true=38.260
x=8.400, y_est=38.799, y_true=38.759
x=8.500, y_est=39.311, y_true=39.259
x=8.600, y_est=39.797, y_true=39.758
x=8.700, y_est=40.318, y_true=40.258
x=8.800, y_est=40.795, y_true=40.757
x=8.900, y_est=41.314, y_true=41.257
x=9.000, y_est=41.790, y_true=41.756
x=9.100, y_est=42.297, y_true=42.256
x=9.200, y_est=42.789, y_true=42.755
x=9.300, y_est=43.297, y_true=43.255
x=9.400, y_est=43.803, y_true=43.754
x=9.500, y_est=44.311, y_true=44.254
x=9.600, y_est=44.805, y_true=44.753
x=9.700, y_est=45.296, y_true=45.253
x=9.800, y_est=45.796, y_true=45.752
x=9.900, y_est=46.303, y_true=46.252
Training ends with 49 steps
end