
[tensorflow] Saving, loading, and running a model

Posted at 2018-03-20

Saving a model

There are two ways to save a model in TensorFlow: check_point and saved_model.

check_point

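Checkpoints are configured by passing a RunConfig to the Estimator. You can set how often a checkpoint is written (RunConfig takes the interval in seconds via save_checkpoints_secs, or in training steps via save_checkpoints_steps). If training is stopped partway through, it can be resumed by loading the latest checkpoint. If nothing is configured, the default settings apply. For loading and running a model, however, saved_model is simpler, and it also lets you use saved_model_cli.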
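A minimal sketch of the checkpoint configuration described above, assuming a TensorFlow version that still ships tf.estimator (it is not available in the newest releases). The directory './mnist_model' and the interval values are illustrative assumptions, not settings from the original article.

```python
import tensorflow as tf

# Write a checkpoint every 5 minutes and keep only the 3 most recent ones.
# model_dir is a hypothetical path; the values are only illustrative.
config = tf.estimator.RunConfig(
    model_dir='./mnist_model',
    save_checkpoints_secs=300,
    keep_checkpoint_max=3,
)

print(config.save_checkpoints_secs, config.keep_checkpoint_max)
```

The config is then passed as tf.estimator.Estimator(model_fn=model_fn, config=config, ...). Calling train again with the same model_dir makes the Estimator restore the latest checkpoint in that directory and continue from there.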

saved_model

A saved_model is created either by building it yourself with SavedModelBuilder inside a session, or by calling the Estimator method export_savedmodel. The following is an export_savedmodel example.

example
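import tensorflow as tf

def model_fn(features, labels, mode, params):
  # Model: the network class (e.g. the one in the official TensorFlow MNIST example); not defined in this snippet
  model = Model(params['data_format'])
  image = features
  # When a loaded saved_model is served, features arrive as a dict keyed by input name
  if isinstance(image, dict):
    image = features['image']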
  if mode == tf.estimator.ModeKeys.PREDICT:
    logits = model(image, training=False)
  ...
# model_dir: directory for checkpoints and exports (e.g. './mnist_model')
mnist_classifier = tf.estimator.Estimator(
    model_fn=model_fn,
    model_dir=model_dir,
    params={
        'data_format': 'channels_last'
    })
# input
image = tf.placeholder(tf.float32, shape=[None, 28, 28], name='image')
# input receiver
input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
    'image': image,
})
# export model
mnist_classifier.export_savedmodel(model_dir, input_fn)
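For the other route, building the saved_model yourself with SavedModelBuilder inside a session, here is a minimal sketch. It uses the TF 1.x graph/session API through tf.compat.v1 (TF 1.13+ or 2.x). The single-layer graph is a hypothetical, untrained stand-in for the MNIST model, and the temporary export directory is an assumption. The same SignatureDef is registered under the keys 'classify' and 'serving_default' only to mirror the saved_model_cli output shown later.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1  # TF 1.x-style graph/session API

# Hypothetical export location; Builder requires that the directory not exist yet.
export_dir = os.path.join(tempfile.mkdtemp(), 'builder_model')

graph = tf.Graph()
with graph.as_default(), tf1.Session(graph=graph) as sess:
    # Toy, untrained stand-in for the MNIST model: one dense layer.
    image = tf1.placeholder(tf.float32, shape=[None, 28, 28], name='image')
    w = tf1.get_variable('w', shape=[28 * 28, 10])
    logits = tf.matmul(tf.reshape(image, [-1, 28 * 28]), w)
    classes = tf.argmax(logits, axis=1, name='classes')
    probabilities = tf.nn.softmax(logits, name='probabilities')
    sess.run(tf1.global_variables_initializer())

    # A SignatureDef that maps input/output keys to tensors.
    signature = tf1.saved_model.predict_signature_def(
        inputs={'image': image},
        outputs={'classes': classes, 'probabilities': probabilities})

    # Save the graph and the current variable values under the 'serve' tag.
    builder = tf1.saved_model.Builder(export_dir)
    builder.add_meta_graph_and_variables(
        sess,
        [tf1.saved_model.tag_constants.SERVING],
        signature_def_map={'classify': signature,
                           'serving_default': signature})
    builder.save()

print(sorted(os.listdir(export_dir)))
```

Unlike export_savedmodel, the Builder does not create a timestamped subdirectory; it writes saved_model.pb and a variables directory straight into export_dir, which can then be inspected with saved_model_cli show --dir.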

Loading and running a model

saved_model_cli

A command for inspecting the contents of a saved_model

command
saved_model_cli show --dir ./mnist_model/1521441078 --all
result
MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['classify']:
The given SavedModel SignatureDef contains the following input(s):
inputs['image'] tensor_info:
    dtype: DT_FLOAT
    shape: (-1, 28, 28)
    name: Placeholder:0
The given SavedModel SignatureDef contains the following output(s):
outputs['classes'] tensor_info:
    dtype: DT_INT64
    shape: (-1)
    name: ArgMax:0
outputs['probabilities'] tensor_info:
    dtype: DT_FLOAT
    shape: (-1, 10)
    name: Softmax:0
Method name is: tensorflow/serving/predict

signature_def['serving_default']:
The given SavedModel SignatureDef contains the following input(s):
inputs['image'] tensor_info:
    dtype: DT_FLOAT
    shape: (-1, 28, 28)
    name: Placeholder:0
The given SavedModel SignatureDef contains the following output(s):
outputs['classes'] tensor_info:
    dtype: DT_INT64
    shape: (-1)
    name: ArgMax:0
outputs['probabilities'] tensor_info:
    dtype: DT_FLOAT
    shape: (-1, 10)
    name: Softmax:0
Method name is: tensorflow/serving/predict

Note the name fields above: Placeholder:0 is the input (the test images) and ArgMax:0 is the output (the predicted classes).

Next, use the saved_model to predict the labels of test data stored in an .npy file (NumPy's saved-array format).
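The article does not show how example2.npy was produced. Here is a minimal sketch, assuming the file must hold a float32 array of shape (N, 28, 28) to match the image input (DT_FLOAT, shape (-1, 28, 28)) in the signature above. The random pixels and the scaling to [0, 1] are hypothetical stand-ins; real data should go through whatever preprocessing the model was trained with.

```python
import numpy as np

def save_images_for_cli(images, path):
    """Save images as float32 with shape (N, 28, 28), matching the 'image' input."""
    images = np.asarray(images, dtype=np.float32).reshape(-1, 28, 28)
    np.save(path, images)
    return images.shape

# Hypothetical stand-in data: 5 flattened uint8 "images", scaled to [0, 1].
raw = np.random.randint(0, 256, size=(5, 784), dtype=np.uint8)
shape = save_images_for_cli(raw / 255.0, 'example2.npy')
print(shape)  # (5, 28, 28)

# Check what saved_model_cli will read back from the file.
loaded = np.load('example2.npy')
print(loaded.dtype, loaded.shape)  # float32 (5, 28, 28)
```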

command
saved_model_cli run --dir ./mnist_model/1521441078 --tag_set serve --signature_def classify --inputs image=./example2.npy
result
Result for output key classes:
[7 2 1 ... 9 8 6]
Result for output key probabilities:
[[5.71908802e-03 5.26234088e-03 5.65170124e-03 ... 8.65403712e-01
  2.58122981e-02 4.70778644e-02]
 [3.00555006e-02 1.64930541e-02 6.53899252e-01 ... 9.86633589e-04
  3.90466303e-02 4.75284224e-03]
 [6.55172905e-03 7.44415641e-01 3.45991701e-02 ... 2.27554981e-02
  5.24941944e-02 3.27617601e-02]
 ...
 [6.86948374e-03 3.37781794e-02 1.31903710e-02 ... 1.70341298e-01
  1.24142714e-01 3.21431905e-01]
 [6.56285435e-02 3.99611443e-02 3.55339721e-02 ... 8.27240124e-02
  2.49543816e-01 1.30154327e-01]
 [1.08888358e-01 8.15967447e-04 1.35427341e-01 ... 3.61581246e-04
  7.29427906e-03 4.25529340e-03]]

The output key classes corresponds to the tensor ArgMax:0, and the output key probabilities corresponds to Softmax:0.
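The two outputs are consistent with each other: each predicted class is the index of the largest probability in its row. Here is a small NumPy illustration of that relationship, using hypothetical 3-class logits rather than real model output. Because softmax preserves ordering, taking the argmax of the logits or of the probabilities gives the same classes.

```python
import numpy as np

def softmax(logits):
    # Subtract the row maximum for numerical stability, then normalize each row.
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

# Hypothetical logits for 2 examples and 3 classes.
logits = np.array([[0.1, 2.0, -1.0],
                   [3.0, 0.5, 0.2]])

probabilities = softmax(logits)          # plays the role of Softmax:0
classes = probabilities.argmax(axis=1)   # plays the role of ArgMax:0
print(classes)  # [1 0]
```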

An example that loads the saved_model into a session and checks the result from Python code

example
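import tensorflow as tf
import numpy as np

# Timestamped directory created by export_savedmodel
export_dir = './mnist_model/1521510868'

# Load the test images and reshape the first one into a batch of size 1
te = np.load('example2.npy')
te0 = np.reshape(te[0,:], (-1, 28, 28))

with tf.Session(graph=tf.Graph()) as sess:
  # Load the saved_model into the session's graph
  tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], export_dir)
  # Input tensor; its name depends on the export (the export inspected above reported
  # Placeholder:0), so check it with saved_model_cli show
  i = sess.graph.get_tensor_by_name("image:0")
  # Output tensor: the predicted class
  r = sess.graph.get_tensor_by_name("ArgMax:0")
  print(sess.run(r, feed_dict={i:te0}))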
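Hard-coding tensor names such as "image:0" is fragile, because they vary between exports (the saved_model_cli output above shows Placeholder:0). A more robust variant reads the names from the SignatureDef returned by the loader. This is a minimal sketch using tf.compat.v1; the helper names predict and save_toy_model are hypothetical. The toy model is an untrained stand-in, saved with simple_save, and its unnamed placeholder gets the name Placeholder:0 even though its signature key is 'image', which reproduces that mismatch.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1
SERVING = tf1.saved_model.tag_constants.SERVING

def predict(export_dir, images, output_key='classes', signature_key='serving_default'):
    """Run a saved_model, resolving tensor names from its SignatureDef."""
    graph = tf.Graph()
    with graph.as_default(), tf1.Session(graph=graph) as sess:
        meta_graph_def = tf1.saved_model.loader.load(sess, [SERVING], export_dir)
        signature = meta_graph_def.signature_def[signature_key]
        # Assumes exactly one input, e.g. inputs['image'] -> 'Placeholder:0'
        (input_info,) = signature.inputs.values()
        x = graph.get_tensor_by_name(input_info.name)
        y = graph.get_tensor_by_name(signature.outputs[output_key].name)
        return sess.run(y, feed_dict={x: images})

def save_toy_model(export_dir):
    """Hypothetical, untrained stand-in for the exported MNIST model."""
    graph = tf.Graph()
    with graph.as_default(), tf1.Session(graph=graph) as sess:
        image = tf1.placeholder(tf.float32, shape=[None, 28, 28])  # 'Placeholder:0'
        w = tf1.get_variable('w', shape=[28 * 28, 10])
        logits = tf.matmul(tf.reshape(image, [-1, 28 * 28]), w)
        sess.run(tf1.global_variables_initializer())
        # simple_save registers one 'serving_default' signature under the 'serve' tag
        tf1.saved_model.simple_save(
            sess, export_dir,
            inputs={'image': image},
            outputs={'classes': tf.argmax(logits, axis=1),
                     'probabilities': tf.nn.softmax(logits)})

export_dir = os.path.join(tempfile.mkdtemp(), 'toy_model')
save_toy_model(export_dir)
images = np.random.rand(4, 28, 28).astype(np.float32)
classes = predict(export_dir, images)
print(classes.shape)  # (4,)
```

For the model in this article, predict('./mnist_model/1521510868', te0) would use the same lookup, provided the export contains a 'serving_default' signature with a 'classes' output, as the saved_model_cli listing shows.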