[TF]Tensorflowの学習パラメータの保存と読み込み方法

  • 97
    いいね
  • 0
    コメント
この記事は最終更新日から1年以上が経過しています。

Tensorflow で学習したパラメータの保存と読み込みには、tf.train.Saverを使用します。

保存

保存をするときは、作成したsaver classのsaveメソッドを使用します。

python
saver = tf.train.Saver()

なんらかの処理

#保存
saver.save(sess, "model.ckpt")

保存は学習の最後でもいいし、学習の途中のタイミングでもいいです。

読み込み

読み込みをするときは、作成したsaver classのrestoreメソッドを使用します。
sessionが必要なので、sessionを作成した後に読み込みます。
ipython上で実行する場合は、tf.InteractiveSession()で、通常はtf.Session()でsessionを作成します。

python
sess=tf.InteractiveSession()

saver.restore(sess, "model.ckpt")

実際に保存と読み込みを行った様子が下記になります。

流れは下記のようになっています。

1.モデル作成
2.学習
3.あとで比較するためにパラメータを別の変数に保存
4.パラメータをファイルに保存
5.Session Close
6.Session作成
7.初期化(本来はこれは必要ありません。比較するためにわざと初期化しました。)
8.保存したパラメータと比較(これはひとつ前で初期化したので差がでます。)
9.ファイルからパラメータを読み込む
10.保存したパラメータと比較(これは一致します)
11.学習

TF_SaveAndRestoreModel-20-1-html.png

コード

python
# # import

# In[1]:

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data


# # load dataset

# In[2]:

mnist = input_data.read_data_sets("./data/mnist/", one_hot=True) 


# # build model

# In[3]:

def mlp(x, output_dim, reuse=False):

    w1 = tf.get_variable("w1", [x.get_shape()[1], 1024], initializer=tf.random_normal_initializer())
    b1 = tf.get_variable("b1", [1024], initializer=tf.constant_initializer(0.0))
    w2 = tf.get_variable("w2", [1024, output_dim], initializer=tf.random_normal_initializer())
    b2 = tf.get_variable("b2", [output_dim], initializer=tf.constant_initializer(0.0))

    fc1 = tf.nn.relu(tf.matmul(x, w1) + b1)
    fc2 = tf.matmul(fc1, w2) + b2

    return fc2, [w1, b1, w2, b2]

def slp(x, output_dim):
    w1 = tf.get_variable("w1", [x.get_shape()[1], output_dim], initializer=tf.random_normal_initializer())
    b1 = tf.get_variable("b1", [output_dim], initializer=tf.constant_initializer(0.0))

    fc1 = tf.nn.relu(tf.matmul(x, w1) + b1)
    return fc1, [w1, b1]

n_batch = 32
n_label = 10
n_train = 10000
imagesize = 28
learning_rate = 0.5

x_node = tf.placeholder(tf.float32, shape=(n_batch, imagesize*imagesize))
y_node = tf.placeholder(tf.float32, shape=(n_batch, n_label))

with tf.variable_scope("MLP") as scope:
    out_m, theta_m = mlp(x_node, n_label)

loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(out_m, y_node))
opt  = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)
tr_pred = tf.nn.softmax(out_m)

test_data = mnist.test.images
test_labels = mnist.test.labels
tx = tf.constant(test_data)
ty_ = tf.constant(test_labels)

with tf.variable_scope("MLP") as scope:
    scope.reuse_variables()
    ty, _ = mlp(tx, n_label)

te_pred = tf.nn.softmax(ty) 


# In[4]:

def accuracy(y, y_):
    return 100.0 * np.sum(np.argmax(y, 1) == np.argmax(y_, 1)) / y.shape[0]


# In[5]:

saver = tf.train.Saver()

sess=tf.InteractiveSession()

init = tf.initialize_all_variables()
sess.run(init)


# In[6]:

for step in xrange(n_train):
    bx, by = mnist.train.next_batch(n_batch)
    feed_dict = {x_node: bx, y_node: by}
    _, _loss, _tr_pred = sess.run([opt, loss, tr_pred], feed_dict=feed_dict)
    if step % 500 == 0:
        _accuracy = accuracy(_tr_pred, by)
        print 'step = %d, loss=%.2f, accuracy=%.2f' % (step, _loss, _accuracy)

print 'test accuracy=%.2f' % accuracy(te_pred.eval(), test_labels)


# In[8]:

old_theta_m = [ p.eval() for p in theta_m] # for comparing


# In[9]:

saver.save(sess, "model.ckpt")


# In[10]:

sess.close()


# In[11]:

sess=tf.InteractiveSession()

# for clear
init = tf.initialize_all_variables()
sess.run(init)


# In[12]:

for prm, prm_o in zip(theta_m, old_theta_m):
    p1 = prm.eval()
    p2 = prm_o
    print np.sum(p1 != p2) 


# In[13]:

saver.restore(sess, "model.ckpt")


# In[14]:

for prm, prm_o in zip(theta_m, old_theta_m):
    p1 = prm.eval()
    p2 = prm_o
    print np.sum(p1 != p2) 


# In[15]:

print 'test accuracy=%.2f' % accuracy(te_pred.eval(), test_labels)


# In[16]:

for step in xrange(n_train):
    bx, by = mnist.train.next_batch(n_batch)
    feed_dict = {x_node: bx, y_node: by}
    _, _loss, _tr_pred = sess.run([opt, loss, tr_pred], feed_dict=feed_dict)
    if step % 500 == 0:
        _accuracy = accuracy(_tr_pred, by)
        print 'step = %d, loss=%.2f, accuracy=%.2f' % (step, _loss, _accuracy)

print 'test accuracy=%.2f' % accuracy(te_pred.eval(), test_labels)


# In[17]:

sess.close()


# In[ ]:

tf.reset_default_graph()