Help us understand the problem. What is going on with this article?

[TensorFlow] 学習の途中経過をセーブしたりリロードしたりする

More than 3 years have passed since last update.

はじめに

TensorFlowで学習途中の重みやバイアスをセーブし、セーブしたデータをリストアして続きを学習するというのをやってみました。

セーブ

saverのsave()メソッドにセッション(session)とファイルパス(file_path)を指定すると、その時点のセッションの情報(変数 etc.)がファイルに保存されます。

import tensorflow as tf

saver = tf.train.Saver()
saver.save(session, file_path)

ファイルパスはディレクトリを含めた指定が可能です。
ただし、存在しないディレクトリは自分で作成する必要があります。

ディレクトリ名をsess、ファイル名をfilenameにしたい場合、file_pathの指定は"sess/filename"になります。この場合、sessディレクトリには以下の3ファイルが作成されます。

  • checkpoint
  • filename
  • filename.meta

リストア

saverのrestore()メソッドにセッション(session)とファイルパス(file_path)を指定すると、保存してあったセッションの情報(変数 etc.)がリストアされます。

saver = tf.train.Saver()
saver.restore(sess, sess_path)

指定したパスにファイルがなければエラーが発生しますので、ファイルが存在するかどうかのチェックが必要です。

プログラム

mnist_softmax.pyにセーブとリロードの機能を追加してみます。
初回はセーブしたデータがないのでプログラムで初期化します。
2回目以降はセーブしたデータでプログラムを初期化します。

mnist.py
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
#
# This program derived from mnist_softmax.py
# tensorflow/examples/tutorials/mnist/mnist_softmax.py

import os
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Define constants
data_dir = "./data/mnist/"
sess_dir = "sess"
sess_file = "sess.info"
imagesize = 28
n_label = 10
n_batch = 100
n_train = 1000
learning_rate = 0.5

# Read MNIST data
mnist = input_data.read_data_sets(data_dir, one_hot=True)

# Create the model
x = tf.placeholder(tf.float32, [None, imagesize ** 2])
W = tf.get_variable("W", [imagesize ** 2, n_label], initializer=tf.random_normal_initializer())
b = tf.get_variable("b", [n_label], initializer=tf.constant_initializer(0.0))
y = tf.matmul(x, W) + b

# Loss and Optimizer
y_ = tf.placeholder(tf.float32, [None, n_label])
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(y, y_))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

with tf.Session() as sess:
    # Graph
    # $ tensorboard --logdir=tensorboard
    tf.train.SummaryWriter('tensorboard', sess.graph)

    # Initialize / Restore variables
    saver = tf.train.Saver()
    sess_path = os.path.join(sess_dir, sess_file)
    if os.path.exists(sess_path):
        saver.restore(sess, sess_path)
    else:
        sess.run(tf.initialize_all_variables())

    # Train model
    for _ in range(n_train):
        batch_xs, batch_ys = mnist.train.next_batch(n_batch)
        sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

    # Test trained model
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    result = sess.run(accuracy, feed_dict={x: mnist.test.images,
                                        y_: mnist.test.labels})
    print('Accuracy => ' + str(result))
    print('')

    # Save
    if os.path.exists(sess_dir) == False:
        os.makedirs(sess_dir)
    saver.save(sess, sess_path)

実行

前回の実行結果の続きから学習を再開するため、徐々に精度が向上していきます。
(データの解凍の出力は省略)

$ ./mnist.py
Accuracy => 0.8785

$ ./mnist.py
Accuracy => 0.8961

$ ./mnist.py
Accuracy => 0.9015

$ ./mnist.py
Accuracy => 0.9028

$ ./mnist.py
Accuracy => 0.9039

$ ./mnist.py
Accuracy => 0.9068

$ ./mnist.py
Accuracy => 0.9131

おわりに

TensorFlowは、Saverを使えばセーブもリロードも1行書けばOKです。
学習済みのデータでモデルを動作させるときに必ず使うことになる操作なので、Saverの使い方はマスターしておきたいですね。

Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away