LoginSignup
7

More than 5 years have passed since last update.

TensorflowでMNISTデータ認識+CNN(Convolutional Neural Net)+学習結果セーブ&ロード

Last updated at Posted at 2016-11-04

はじめに

今回は非力なマシンで学習をしてる途中を保存→再開→再開→...→テストということを想定したTensorflowのセーブ&ロード機能を簡単にまとめよう。
# fine-tuningするする詐欺中
# 学習後のパタメータの公開とか、過学習の有無を別のマシンで確認するとか。いろいろ応用可能な便利機能。

個別にセッション中のパラメータを保存する方法もあるが、今回は全部保存する方法に限定する。
前述の通り、Tensorflowのセーブ&ロード機能である。そのため、保存できるのはTensorflowで宣言できるパラメータのみだ。
個人的に必要なパラメータは個人でpklとか使って保存しよう。

なお、参考にしたサイトはここである。
# 同じことでハマったのでとても助かった。

ソース

実行順序は以下の通り。
(1)mnist_CNN_Graph_adhoc_with_saver.py:このソースで学習&パラメータをセーブ
(2)mnist_CNN_Graph_adhoc_with_restore.py:このソースでパラメータをロード&テスト

このソースは必須!
ExtendedTensorflowCNN.py:適当設計クラス

実装の概要

処理内容は、以前の投稿そのままである。
# 変更点は、結果をすぐに確認できるように、学習回数を10回にしたぐらいだ。

実行環境

ざっくり以下の環境。
・Mac OS X 10.10.5
・Python 3.5.1
・virtualenv
・IPython

処理内容(特筆すべき処理)

セーブするために必要なこと

mnist_CNN_Graph_adhoc_with_saver.py を紐解いていく。

名前をつけよう(nameを使うべし!)

以下のように、セーブしたいパラメータには name='x' のように名前をつけておこう。
別のプログラムでロードするとき、この名前と同じにしていれば自動的に値がリストアされるのだ。

mnist_CNN_Graph_adhoc_with_saver.py
    # make input and output
    with tf.name_scope('input') as scope:
        x = tf.placeholder(tf.float32, shape=[None, 784], name='x')
        x_image = tf.reshape(x, [-1,28,28,1], name='x-pixel_order')
    with tf.name_scope('teach') as scope:
        y_ = tf.placeholder(tf.float32, shape=[None, 10], name='d')
    ##
    # そのほかのパラメータは ExtendedTensorflowCNN.py で宣言している。
    ##

パラメータをセーブする準備

mnist_CNN_Graph_adhoc_with_saver.py
    # create saver
    saver = tf.train.Saver()

パラメータをセーブする処理

以下の処理を実行すると、指定したファイル名でパラメータが保存される。
学習の途中を保存したい場合は、 global_step で step数を指定すればいい。
すると、"指定したファイル名-step数" と "指定したファイル名-step数.meta" が保存される。

mnist_CNN_Graph_adhoc_with_saver.py
saver_file_name = 'save/mnist-CNN.saver.tf'
# ...(途中省略)...
    saver.save(sess, saver_file_name, global_step=step+1) 

ロードするために必要なこと

mnist_CNN_Graph_adhoc_with_restore.py を紐解いていく。

もう一度名前をつけよう!

保存したパラメータをロードするためには、保存したときと同じ名前のパラメータが必要である。

mnist_CNN_Graph_adhoc_with_restore.py
    # make input and output
    with tf.name_scope('input') as scope:
        x = tf.placeholder(tf.float32, shape=[None, 784], name='x')
        x_image = tf.reshape(x, [-1,28,28,1], name='x-pixel_order')
    with tf.name_scope('teach') as scope:
        y_ = tf.placeholder(tf.float32, shape=[None, 10], name='d')
    ##
    # そのほかのパラメータは ExtendedTensorflowCNN.py で宣言している。
    ##

パラメータをロードする準備&処理

mnist_CNN_Graph_adhoc_with_restore.py
        # create saver & restore 
        # if you wanna use new Tensor, then you must use "tf.initialize_variables()" to initialize it.
        saver = tf.train.Saver()
        saver.restore(sess, saver_file_name + "-10")            

処理結果

セーブしたファイルの中を見てみよう!

mnist-CNN.saver.tf-10.meta

スクリーンショット 2016-11-04 13.19.30.png

"input" の中に "x" や "x-pixel_order"、"teach" の中に "d" といった Tensorboardでも確認できたパラメータの構造が保存されている。
# 値は読めないけど。

mnist-CNN.saver.tf-10

スクリーンショット 2016-11-04 13.33.48.png

下のように、ExtendedTensorflowCNN.py で宣言したパラメータ(学習で変化するパラメータ)などの値がツラツラと保存されている。
# 値は読めないけど。

ExtendedTensorflowCNN.py
class ExtendedTensorflowCNN():

#    def __init__(self):

    def inference(self, input_placeholder):
        with tf.name_scope('first_convolutional_layer') as scope:
            W_conv1 = weight_variable([5, 5, 1, 32])
            b_conv1 = bias_variable([32])
            h_conv1 = tf.nn.relu(conv2d(input_placeholder, W_conv1) + b_conv1)
            h_pool1 = max_pool_2x2(h_conv1)

        # ... (途中省略) ...

        # Readout Layer
        with tf.name_scope('Readout_Layer') as scope:
            W_fc2 = weight_variable([1024, 10])
            b_fc2 = bias_variable([10])
            self.y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2

        # ... (以下省略) ...

テストを実行しよう!

今回は2回、mnist_CNN_Graph_adhoc_with_restore.py を実行した。
その結果を以下に貼り付ける。

スクリーンショット 2016-11-04 13.45.26.png

スクリーンショット 2016-11-04 13.45.40.png

(学習を10回しかしていないので正答率は気にせず)テスト結果を比較すると、2回とも同じスコアだった。
無事、セーブ&ロード機能の確認終了である。

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7