Edited at

[TensorFlow] 学習の途中経過をセーブしたりリロードしたりする

More than 3 years have passed since last update.


はじめに

TensorFlowで学習途中の重みやバイアスをセーブし、セーブしたデータをリストアして続きを学習するというのをやってみました。


セーブ

saverのsave()メソッドにセッション(session)とファイルパス(file_path)を指定すると、その時点のセッションの情報(変数 etc.)がファイルに保存されます。

import tensorflow as tf

saver = tf.train.Saver()
saver.save(session, file_path)

ファイルパスはディレクトリを含めた指定が可能です。

ただし、存在しないディレクトリは自分で作成する必要があります。

ディレクトリ名をsess、ファイル名をfilenameにしたい場合、file_pathの指定は"sess/filename"になります。この場合、sessディレクトリには以下の3ファイルが作成されます。


  • checkpoint

  • filename

  • filename.meta


リストア

saverのrestore()メソッドにセッション(session)とファイルパス(file_path)を指定すると、保存してあったセッションの情報(変数 etc.)がリストアされます。

saver = tf.train.Saver()

saver.restore(sess, sess_path)

指定したパスにファイルがなければエラーが発生しますので、ファイルが存在するかどうかのチェックが必要です。


プログラム

mnist_softmax.pyにセーブとリロードの機能を追加してみます。

初回はセーブしたデータがないのでプログラムで初期化します。

2回目以降はセーブしたデータでプログラムを初期化します。


mnist.py

#! /usr/bin/env python3

# -*- coding: utf-8 -*-
#
# This program derived from mnist_softmax.py
# tensorflow/examples/tutorials/mnist/mnist_softmax.py

import os
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Define constants
data_dir = "./data/mnist/"
sess_dir = "sess"
sess_file = "sess.info"
imagesize = 28
n_label = 10
n_batch = 100
n_train = 1000
learning_rate = 0.5

# Read MNIST data
mnist = input_data.read_data_sets(data_dir, one_hot=True)

# Create the model
x = tf.placeholder(tf.float32, [None, imagesize ** 2])
W = tf.get_variable("W", [imagesize ** 2, n_label], initializer=tf.random_normal_initializer())
b = tf.get_variable("b", [n_label], initializer=tf.constant_initializer(0.0))
y = tf.matmul(x, W) + b

# Loss and Optimizer
y_ = tf.placeholder(tf.float32, [None, n_label])
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(y, y_))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

with tf.Session() as sess:
# Graph
# $ tensorboard --logdir=tensorboard
tf.train.SummaryWriter('tensorboard', sess.graph)

# Initialize / Restore variables
saver = tf.train.Saver()
sess_path = os.path.join(sess_dir, sess_file)
if os.path.exists(sess_path):
saver.restore(sess, sess_path)
else:
sess.run(tf.initialize_all_variables())

# Train model
for _ in range(n_train):
batch_xs, batch_ys = mnist.train.next_batch(n_batch)
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

# Test trained model
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
result = sess.run(accuracy, feed_dict={x: mnist.test.images,
y_: mnist.test.labels})
print('Accuracy => ' + str(result))
print('')

# Save
if os.path.exists(sess_dir) == False:
os.makedirs(sess_dir)
saver.save(sess, sess_path)



実行

前回の実行結果の続きから学習を再開するため、徐々に精度が向上していきます。

(データの解凍の出力は省略)

$ ./mnist.py

Accuracy => 0.8785

$ ./mnist.py
Accuracy => 0.8961

$ ./mnist.py
Accuracy => 0.9015

$ ./mnist.py
Accuracy => 0.9028

$ ./mnist.py
Accuracy => 0.9039

$ ./mnist.py
Accuracy => 0.9068

$ ./mnist.py
Accuracy => 0.9131


おわりに

TensorFlowは、Saverを使えばセーブもリロードも1行書けばOKです。

学習済みのデータでモデルを動作させるときに必ず使うことになる操作なので、Saverの使い方はマスターしておきたいですね。