LoginSignup
145
156

More than 5 years have passed since last update.

[TF]Tensorflowの学習パラメータの保存と読み込み方法

Posted at

Tensorflow で学習したパラメータの保存と読み込みには、tf.train.Saverを使用します。

保存

保存をするときは、作成したsaver classのsaveメソッドを使用します。

python
saver = tf.train.Saver()

なんらかの処理

#保存
saver.save(sess, "model.ckpt")

保存は学習の最後でもいいし、学習の途中のタイミングでもいいです。

読み込み

読み込みをするときは、作成したsaver classのrestoreメソッドを使用します。
sessionが必要なので、sessionを作成した後に読み込みます。
ipython上で実行する場合は、tf.InteractiveSession()で、通常はtf.Session()でsessionを作成します。

python
sess=tf.InteractiveSession()

saver.restore(sess, "model.ckpt")

実際に保存と読み込みを行った様子が下記になります。

流れは下記のようになっています。

1.モデル作成
2.学習
3.あとで比較するためにパラメータを別の変数に保存
4.パラメータをファイルに保存
5.Session Close
6.Session作成
7.初期化(本来はこれは必要ありません。比較するためにわざと初期化しました。)
8.保存したパラメータと比較(これはひとつ前で初期化したので差がでます。)
9.ファイルからパラメータを読み込む
10.保存したパラメータと比較(これは一致します)
11.学習

TF_SaveAndRestoreModel-20-1-html.png

コード

python
# # import

# In[1]:

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data


# # load dataset

# In[2]:

mnist = input_data.read_data_sets("./data/mnist/", one_hot=True) 


# # build model

# In[3]:

def mlp(x, output_dim, reuse=False):

    w1 = tf.get_variable("w1", [x.get_shape()[1], 1024], initializer=tf.random_normal_initializer())
    b1 = tf.get_variable("b1", [1024], initializer=tf.constant_initializer(0.0))
    w2 = tf.get_variable("w2", [1024, output_dim], initializer=tf.random_normal_initializer())
    b2 = tf.get_variable("b2", [output_dim], initializer=tf.constant_initializer(0.0))

    fc1 = tf.nn.relu(tf.matmul(x, w1) + b1)
    fc2 = tf.matmul(fc1, w2) + b2

    return fc2, [w1, b1, w2, b2]

def slp(x, output_dim):
    w1 = tf.get_variable("w1", [x.get_shape()[1], output_dim], initializer=tf.random_normal_initializer())
    b1 = tf.get_variable("b1", [output_dim], initializer=tf.constant_initializer(0.0))

    fc1 = tf.nn.relu(tf.matmul(x, w1) + b1)
    return fc1, [w1, b1]

n_batch = 32
n_label = 10
n_train = 10000
imagesize = 28
learning_rate = 0.5

x_node = tf.placeholder(tf.float32, shape=(n_batch, imagesize*imagesize))
y_node = tf.placeholder(tf.float32, shape=(n_batch, n_label))

with tf.variable_scope("MLP") as scope:
    out_m, theta_m = mlp(x_node, n_label)

loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(out_m, y_node))
opt  = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)
tr_pred = tf.nn.softmax(out_m)

test_data = mnist.test.images
test_labels = mnist.test.labels
tx = tf.constant(test_data)
ty_ = tf.constant(test_labels)

with tf.variable_scope("MLP") as scope:
    scope.reuse_variables()
    ty, _ = mlp(tx, n_label)

te_pred = tf.nn.softmax(ty) 


# In[4]:

def accuracy(y, y_):
    return 100.0 * np.sum(np.argmax(y, 1) == np.argmax(y_, 1)) / y.shape[0]


# In[5]:

saver = tf.train.Saver()

sess=tf.InteractiveSession()

init = tf.initialize_all_variables()
sess.run(init)


# In[6]:

for step in xrange(n_train):
    bx, by = mnist.train.next_batch(n_batch)
    feed_dict = {x_node: bx, y_node: by}
    _, _loss, _tr_pred = sess.run([opt, loss, tr_pred], feed_dict=feed_dict)
    if step % 500 == 0:
        _accuracy = accuracy(_tr_pred, by)
        print 'step = %d, loss=%.2f, accuracy=%.2f' % (step, _loss, _accuracy)

print 'test accuracy=%.2f' % accuracy(te_pred.eval(), test_labels)


# In[8]:

old_theta_m = [ p.eval() for p in theta_m] # for comparing


# In[9]:

saver.save(sess, "model.ckpt")


# In[10]:

sess.close()


# In[11]:

sess=tf.InteractiveSession()

# for clear
init = tf.initialize_all_variables()
sess.run(init)


# In[12]:

for prm, prm_o in zip(theta_m, old_theta_m):
    p1 = prm.eval()
    p2 = prm_o
    print np.sum(p1 != p2) 


# In[13]:

saver.restore(sess, "model.ckpt")


# In[14]:

for prm, prm_o in zip(theta_m, old_theta_m):
    p1 = prm.eval()
    p2 = prm_o
    print np.sum(p1 != p2) 


# In[15]:

print 'test accuracy=%.2f' % accuracy(te_pred.eval(), test_labels)


# In[16]:

for step in xrange(n_train):
    bx, by = mnist.train.next_batch(n_batch)
    feed_dict = {x_node: bx, y_node: by}
    _, _loss, _tr_pred = sess.run([opt, loss, tr_pred], feed_dict=feed_dict)
    if step % 500 == 0:
        _accuracy = accuracy(_tr_pred, by)
        print 'step = %d, loss=%.2f, accuracy=%.2f' % (step, _loss, _accuracy)

print 'test accuracy=%.2f' % accuracy(te_pred.eval(), test_labels)


# In[17]:

sess.close()


# In[ ]:

tf.reset_default_graph()
145
156
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
145
156