13
17

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

TensorFlowで学習したモデルをProtocol Bufferで出力する方法

Posted at

はじめに

TensorFlowでモデルを学習させた時はどのように保存していますか?
saver.restore()で保存して.ckptファイルとして出力することが多いと思います。
PC上で動かす場合はそれで良いですが、Androidなど、他の用途に使いたい場合はProtocol Buffer(.pbファイル)で出力することが必要になることがあります。

この記事では、MNIST for ML Beginnersのチュートリアルを例に.pbファイルで出力していきます。

MNIST for ML Beginnersを学習

まずはTensorFlowのインポートとMNISTの読み込みから

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
import tensorflow as tf

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz

次に適当にgraph1を作ってそこでモデルを構築します。
graphを使っていること以外はほぼそのままです。

graph1 = tf.Graph()
with graph1.as_default():
    x = tf.placeholder(tf.float32, [None, 784])
    W = tf.Variable(tf.zeros([784, 10]))
    b = tf.Variable(tf.zeros([10]))

    y = tf.nn.softmax(tf.matmul(x, W) + b)
    
    y_ = tf.placeholder(tf.float32, [None, 10])

    cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
    train_step = tf.train.GradientDescentOptimizer(0.05).minimize(cross_entropy)

    correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    sess = tf.InteractiveSession()

    tf.global_variables_initializer().run()
    
    for _ in range(1000):
        batch_xs, batch_ys = mnist.train.next_batch(100)
        sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
    print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

0.9006

1000ミニバッチ学習した結果、Accuracyは90%程度でした。
ここまではチュートリアルほぼそのままですね。

Protocol Buffer形式で保存

ここからが本題で、別のgraphに、先ほどのVariableを全てConstantに変更したモデルを用意します。
学習済みなのでVariableにしておく必要はありません。Constantにしましょう。
また、nameはきちんと書いてください。
基本的に、ここでの変数名x_2などではなくnameを使って呼び出します。

tf.train.write_graph(graph,".","mnist.pb",as_text=False)

この部分でgraphをカレントディレクトリ"."に"mnist.pb"という名前で保存しています。

graph = tf.Graph()
with graph.as_default():
    x_2 = tf.placeholder(tf.float32, [1, 784],name="input")
    W_2 = tf.constant(sess.run(W),name="Weight")
    b_2 = tf.constant(sess.run(b),name="bias")
    y_2 = tf.nn.softmax(tf.matmul(x_2, W_2) + b_2,name="output")
    sess2 = tf.Session()
    
    sess2.run(tf.global_variables_initializer())
    
    tf.train.write_graph(graph,".","mnist.pb",as_text=False)

これで保存されているはずです。
Python3/TensorFlow1.3/Macの環境でしか試していないので、他環境で動かなければご連絡ください。

13
17
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
13
17

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?