32
18

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

DiverseAdvent Calendar 2018

Day 7

TensorFlow で学習したモデルのグラフを `tf.train.import_meta_graph` でロードする

Last updated at Posted at 2018-12-06

この記事はなに

TensorFlow のちょっとした Tips です。
モデルに学習済みパラメータをロードするときのことについて書きます。

空いていた(?)ので、Diverse Advent Calendar 2018に12/7の記事としてお邪魔しました。ゆかりのある人ならOKらしい。
昨日、12/6は @imaizume さんによる アプリのロジックからisIPhoneXフラグを消すためにやったこと でした。以前一緒に働いていた身として大変申し訳無い気持ちになりつつ、とても分かりやすく起こった問題を解説されていてためになりました。
明日、12/8は @python_spameggs さんによる記事です。最後にお会いしたのが結構前なので元気そうな記事だと嬉しいです(私信)。

本題です。

TensorFlow で学習したモデルのグラフを tf.train.import_meta_graph からロードする

tf.train.import_meta_graph を使うことで、checkpoint 保存時にできる .meta ファイルからネットワークをロードすることが出来ます。
これにより、モデルの入出力のインターフェースがわかっていれば、モデルのインスタンスを作ることなく推論などを行うことが出来ます。
(チーム内で私の先生的なポジションの方に教えてもらいました :bow: )

以下で、モデルクラスを使ったロードと、 tf.train.import_meta_graph を使ったロードでそれぞれ違いを見てみます。
内容そんなにないですが使った Notebook を一応置いておきます

前提のモデル

例なので簡単な mnist 用のモデルを用意します(低レベルAPIですみません)。

class Model:
    
    def __init__(self, config):
        # placeholder
        self.config = config
        self.x = tf.placeholder(dtype=tf.float32, shape=[None, config.num_inputs], name='inputs')
        self.y = tf.placeholder(dtype=tf.float32, shape=[None, config.num_outputs], name='outputs')
        self.is_training = tf.placeholder(dtype=tf.bool, name='is_training')

        outputs = self.build(self.x)
        self.loss = tf.reduce_mean(-tf.reduce_sum(self.y * tf.log(outputs + 10e-8), 1))
        y_label = tf.argmax(self.y, 1)
        outputs_label = tf.argmax(outputs, 1)
        correct = tf.equal(y_label, outputs_label)
        self.accuracy = tf.reduce_mean(tf.cast(correct, tf.float32), name='accuracy')
        

    def build(self, x):
        layer = tf.layers.dropout(x, rate=self.config.dropout_rate, training=self.is_training)
        for l in range(self.config.num_layers):
            layer = tf.layers.dense(layer, self.config.num_units, tf.nn.relu, kernel_initializer=tf.truncated_normal_initializer(stddev=0.1), name='layer_{}'.format(l))
            layer = tf.layers.dropout(layer, rate=self.config.dropout_rate , training=self.is_training, name='layer_dropout_{}'.format(l))

        out = tf.layers.dense(layer,
                              self.config.num_outputs, 
                              tf.nn.softmax, 
                              kernel_initializer=tf.truncated_normal_initializer(stddev=0.1), 
                              name='out_layer')
        return out

次のような感じで mnist のデータを使ってざっとトレーニングします。

model = Model()
global_step = tf.train.get_or_create_global_step()
optimizer = tf.train.AdamOptimizer(learning_rate)
train_op = optimizer.minimize(model.loss, global_step=global_step)

saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(5000):
        train_data, train_label = mnist.train.next_batch(batch_size)
        _, loss = sess.run([train_op, model.loss], feed_dict={
            model.x: train_data,
            model.y: train_label,
            model.is_training: True
        })
        step = sess.run(global_step)
        if step % 500 == 0:
            print('{} step loss: {}'.format(step, loss))
    saver.save(sess, 'ckpt/model.ckpt', step)

終了後、ckpt dir を見ると次のように学習したパラメータなどが checkpoint として保存されています。

$ ls ckpt
checkpoint  model.ckpt-5000.data-00000-of-00001  model.ckpt-5000.index  model.ckpt-5000.meta

ここから、保存された checkpoint を使ってモデルをロードし、テストデータの推論を行います。

モデルクラスを使ったロード

Model のインスタンスを作成し、saver から restore を行った後、 Model のインスタンスを通じて session.run や placeholder の指定を行っています。

tf.reset_default_graph()
model = Model()
saver = tf.train.Saver()
with tf.Session() as sess:
    ckpt_path = tf.train.latest_checkpoint('ckpt/')
    saver.restore(sess, ckpt_path)
    res = sess.run(model.accuracy, feed_dict={
        model.x: test_data,
        model.y: test_label,
        model.is_training: False
    })
print('accuracy: ', res)

> INFO:tensorflow:Restoring parameters from ckpt/model.ckpt-5000
> accuracy:  0.9793

tf.train.import_meta_graph を使ったロード

Model のインスタンスを作成せずに、 tf.train.import_meta_graph を使ってモデル構造を import しています。
placeholder や run の対象は、モデルのインスタンスがないので string で inputs:0 のように指定しています。

tf.reset_default_graph()
with tf.Session() as sess:
    ckpt_path = tf.train.latest_checkpoint('ckpt/')
    saver = tf.train.import_meta_graph(ckpt_path + '.meta')
    saver.restore(sess, ckpt_path)
    res = sess.run('accuracy:0', feed_dict={
        'inputs:0': test_data,
        'outputs:0': test_label,
        'is_training:0': False
    })
print('accuracy: ', res)

> INFO:tensorflow:Restoring parameters from ckpt/model.ckpt-5000
> accuracy:  0.9793

string で指定する以外にも、次のように対象となる Tensor を get_tensor_by_name を使って取得することも出来ます。

is_training = tf.get_default_graph().get_tensor_by_name('is_training:0')
inputs = tf.get_default_graph().get_tensor_by_name('inputs:0')
outputs = tf.get_default_graph().get_tensor_by_name('outputs:0')
accuracy = tf.get_default_graph().get_tensor_by_name('accuracy:0')
res = sess.run(accuracy, feed_dict={
    inputs: test_data,
    outputs: test_label,
    is_training: False
})
> INFO:tensorflow:Restoring parameters from ckpt/model.ckpt-5000
> accuracy:  0.9793

.meta からのロードのメリット・デメリット

メリット

一番の理由は、**「モデル構造と学習したパラメータの対応を意識する必要がない」**ことでした。
(※これ関連で困っていたときに教えていただきました。)

学習済みモデルを使って推論を行う場合、当たり前ですが次の2つがちゃんと一致している必要があります。

  • モデルのグラフ構造
  • 学習したパラメータ

例えば、4層の DNN に5層の DNN で学習したパラメータを突っ込んでも失敗するわけです。
もちろんパラメータをアレコレして一部を使ったり置き換えたりで無理やり使ったりすることは出来ますが、今は「学習したモデルをそのまま推論に使いたい」というシーンでの話なので、これは意図した挙動ではありません。

モデルのインスタンスを使ったロードの場合、グラフ構造を class 内に定義してトレーニングを行い、合わせて出力される学習済みパラメータを保存します。
推論の際には、再度モデルクラスを読み込むことでグラフを作成し、先ほど保存したパラメータをモデルにロードすることで推論が行えるようになります。

# 再掲 > モデルクラスを使ったロード
# ... 略 ...
model = Model()
# ... 略 ...
    saver.restore(sess, ckpt_path)
    res = sess.run(model.accuracy, feed_dict={
        model.x: test_data,
        model.y: test_label,
        model.is_training: False
    })

この時、 Model のインスタンス作成時に作られるグラフと対応した checkpoint が ckpt_path にないとエラーになります。
すなわち、 Model と checkpoint のパスは何かしらの方法で一緒に管理する必要が出てきます。

モデルが1つだったり、少数なら問題なさそうですね。
また、Production 環境ならちゃんとモデルとパラメータのミスマッチが起きない設計で管理すべきです。
しかし、手元で色々と実験をしていてモデル数が増えてきて、ちょっと動作を見比べたい、などのときには少しややこしかったりします。
例えば、ハイパーパラメータのチューニングでモデルを複数作っていたりすると辛くなってきます。

そこで、checkpoint 保存時に一緒に作られる .meta ファイルを使ってモデルのグラフ構造も一緒にロードすることで、管理を一箇所にまとめることが出来ます。

# 再掲 > `tf.train.import_meta_graph` を使ったロード

with tf.Session() as sess:
    ckpt_path = tf.train.latest_checkpoint('ckpt/')
    saver = tf.train.import_meta_graph(ckpt_path + '.meta')
    saver.restore(sess, ckpt_path)
    res = sess.run('accuracy:0', feed_dict={
        'inputs:0': test_data,
        'outputs:0': test_label,
        'is_training:0': False
    })

tf.train.import_meta_graph を使うことで、推論に Model クラスを必要としていません。
代わりに、 string でノードの name を指定しています。
これにより、モデルの中身が何であろうと( units が何個でも、 layer が何層であっても、全く違う構造であっても)、入出力さえ定義されていれば推論を行うことが出来るようになります。
もちろん、 .meta ファイルと他の checkpoint のファイルが一緒に保存されている必要がありますが、これらは何もしなければ同じ場所に保存されるので、ミスマッチが起きることは基本ないでしょう。

これにより、モデルを変えても推論コードをいじる必要がなく、疎結合な推論コードを用意することが出来ます。

デメリット

いくつかあります。

string で Tensor を指定しないといけない

ミスでバグを生む可能性があります。見栄えもちょっと良くないかも。
入出力のインターフェースが決まっているなら、それ用の Struct などを用意しておくと良さそうです。

class ModelProtocol(NamedTuple):
    inputs: str = 'input:0'
    outputs: str = 'outputs:0'
    is_training: str = 'is_training:0'
    accuracy: str = 'accuracy:0'

グラフに何があるのかが見えづらくなる

tf.train.import_meta_graph を使うと現在のグラフ上にノードが読み込まれます。
モデルクラスを定義した場合のように元となるソースがあるわけではないので、どのようなノードがあるのか分かりづらくなり、グラフ内での名前の衝突などが起きやすくなることが考えられます。

簡単にできる名前空間の対策としては、tf.train.import_meta_graph の引数 import_scope= を指定することで、ロードする変数のスコープを指定できます。

saver = tf.train.import_meta_graph(ckpt_path + '.meta', import_scope='my_model_01')

衝突は避けやすくなりますが、どのようなノードがあるのか見えづらいのは変わらないので、扱う際は気をつけましょう。

まとめ

TensorFlow でモデルをトレーニングし、保存したパラメータ (checkpoint) を推論時などにロードするための方法として、 tf.train.import_meta_graph を紹介しました。
基本的にはモデルを定義したクラスを使う方法で問題ないのですが、多くのモデルとその checkpoint で実験をしている場合、インターフェース定義だけ合っていればロードして推論ができるのはメリットになることを紹介しました。

なにか間違いなどありましたらお気軽にご指摘いただけると幸いです :bow:

32
18
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
32
18

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?