はじめに
TensorFlowで学習途中の重みやバイアスをセーブし、セーブしたデータをリストアして続きを学習するというのをやってみました。
セーブ
saverのsave()メソッドにセッション(session)とファイルパス(file_path)を指定すると、その時点のセッションの情報(変数 etc.)がファイルに保存されます。
import tensorflow as tf
saver = tf.train.Saver()
saver.save(session, file_path)
ファイルパスはディレクトリを含めた指定が可能です。
ただし、存在しないディレクトリは自分で作成する必要があります。
ディレクトリ名をsess
、ファイル名をfilename
にしたい場合、file_pathの指定は"sess/filename"
になります。この場合、sess
ディレクトリには以下の3ファイルが作成されます。
- checkpoint
- filename
- filename.meta
リストア
saverのrestore()メソッドにセッション(session)とファイルパス(file_path)を指定すると、保存してあったセッションの情報(変数 etc.)がリストアされます。
saver = tf.train.Saver()
saver.restore(sess, sess_path)
指定したパスにファイルがなければエラーが発生しますので、ファイルが存在するかどうかのチェックが必要です。
プログラム
mnist_softmax.pyにセーブとリロードの機能を追加してみます。
初回はセーブしたデータがないのでプログラムで初期化します。
2回目以降はセーブしたデータでプログラムを初期化します。
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
#
# This program derived from mnist_softmax.py
# tensorflow/examples/tutorials/mnist/mnist_softmax.py
import os
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# Define constants
data_dir = "./data/mnist/"
sess_dir = "sess"
sess_file = "sess.info"
imagesize = 28
n_label = 10
n_batch = 100
n_train = 1000
learning_rate = 0.5
# Read MNIST data
mnist = input_data.read_data_sets(data_dir, one_hot=True)
# Create the model
x = tf.placeholder(tf.float32, [None, imagesize ** 2])
W = tf.get_variable("W", [imagesize ** 2, n_label], initializer=tf.random_normal_initializer())
b = tf.get_variable("b", [n_label], initializer=tf.constant_initializer(0.0))
y = tf.matmul(x, W) + b
# Loss and Optimizer
y_ = tf.placeholder(tf.float32, [None, n_label])
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(y, y_))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
with tf.Session() as sess:
# Graph
# $ tensorboard --logdir=tensorboard
tf.train.SummaryWriter('tensorboard', sess.graph)
# Initialize / Restore variables
saver = tf.train.Saver()
sess_path = os.path.join(sess_dir, sess_file)
if os.path.exists(sess_path):
saver.restore(sess, sess_path)
else:
sess.run(tf.initialize_all_variables())
# Train model
for _ in range(n_train):
batch_xs, batch_ys = mnist.train.next_batch(n_batch)
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
# Test trained model
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
result = sess.run(accuracy, feed_dict={x: mnist.test.images,
y_: mnist.test.labels})
print('Accuracy => ' + str(result))
print('')
# Save
if os.path.exists(sess_dir) == False:
os.makedirs(sess_dir)
saver.save(sess, sess_path)
実行
前回の実行結果の続きから学習を再開するため、徐々に精度が向上していきます。
(データの解凍の出力は省略)
$ ./mnist.py
Accuracy => 0.8785
$ ./mnist.py
Accuracy => 0.8961
$ ./mnist.py
Accuracy => 0.9015
$ ./mnist.py
Accuracy => 0.9028
$ ./mnist.py
Accuracy => 0.9039
$ ./mnist.py
Accuracy => 0.9068
$ ./mnist.py
Accuracy => 0.9131
おわりに
TensorFlowは、Saverを使えばセーブもリロードも1行書けばOKです。
学習済みのデータでモデルを動作させるときに必ず使うことになる操作なので、Saverの使い方はマスターしておきたいですね。