この記事はなに
TensorFlow のちょっとした Tips です。
モデルに学習済みパラメータをロードするときのことについて書きます。
空いていた(?)ので、Diverse Advent Calendar 2018に12/7の記事としてお邪魔しました。ゆかりのある人ならOKらしい。
昨日、12/6は @imaizume さんによる アプリのロジックからisIPhoneXフラグを消すためにやったこと でした。以前一緒に働いていた身として大変申し訳無い気持ちになりつつ、とても分かりやすく起こった問題を解説されていてためになりました。
明日、12/8は @python_spameggs さんによる記事です。最後にお会いしたのが結構前なので元気そうな記事だと嬉しいです(私信)。
本題です。
TensorFlow で学習したモデルのグラフを tf.train.import_meta_graph
からロードする
tf.train.import_meta_graph
を使うことで、checkpoint 保存時にできる .meta
ファイルからネットワークをロードすることが出来ます。
これにより、モデルの入出力のインターフェースがわかっていれば、モデルのインスタンスを作ることなく推論などを行うことが出来ます。
(チーム内で私の先生的なポジションの方に教えてもらいました )
以下で、モデルクラスを使ったロードと、 tf.train.import_meta_graph
を使ったロードでそれぞれ違いを見てみます。
内容そんなにないですが使った Notebook を一応置いておきます 。
前提のモデル
例なので簡単な mnist 用のモデルを用意します(低レベルAPIですみません)。
class Model:
def __init__(self, config):
# placeholder
self.config = config
self.x = tf.placeholder(dtype=tf.float32, shape=[None, config.num_inputs], name='inputs')
self.y = tf.placeholder(dtype=tf.float32, shape=[None, config.num_outputs], name='outputs')
self.is_training = tf.placeholder(dtype=tf.bool, name='is_training')
outputs = self.build(self.x)
self.loss = tf.reduce_mean(-tf.reduce_sum(self.y * tf.log(outputs + 10e-8), 1))
y_label = tf.argmax(self.y, 1)
outputs_label = tf.argmax(outputs, 1)
correct = tf.equal(y_label, outputs_label)
self.accuracy = tf.reduce_mean(tf.cast(correct, tf.float32), name='accuracy')
def build(self, x):
layer = tf.layers.dropout(x, rate=self.config.dropout_rate, training=self.is_training)
for l in range(self.config.num_layers):
layer = tf.layers.dense(layer, self.config.num_units, tf.nn.relu, kernel_initializer=tf.truncated_normal_initializer(stddev=0.1), name='layer_{}'.format(l))
layer = tf.layers.dropout(layer, rate=self.config.dropout_rate , training=self.is_training, name='layer_dropout_{}'.format(l))
out = tf.layers.dense(layer,
self.config.num_outputs,
tf.nn.softmax,
kernel_initializer=tf.truncated_normal_initializer(stddev=0.1),
name='out_layer')
return out
次のような感じで mnist のデータを使ってざっとトレーニングします。
model = Model()
global_step = tf.train.get_or_create_global_step()
optimizer = tf.train.AdamOptimizer(learning_rate)
train_op = optimizer.minimize(model.loss, global_step=global_step)
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for i in range(5000):
train_data, train_label = mnist.train.next_batch(batch_size)
_, loss = sess.run([train_op, model.loss], feed_dict={
model.x: train_data,
model.y: train_label,
model.is_training: True
})
step = sess.run(global_step)
if step % 500 == 0:
print('{} step loss: {}'.format(step, loss))
saver.save(sess, 'ckpt/model.ckpt', step)
終了後、ckpt
dir を見ると次のように学習したパラメータなどが checkpoint として保存されています。
$ ls ckpt
checkpoint model.ckpt-5000.data-00000-of-00001 model.ckpt-5000.index model.ckpt-5000.meta
ここから、保存された checkpoint を使ってモデルをロードし、テストデータの推論を行います。
モデルクラスを使ったロード
Model
のインスタンスを作成し、saver から restore を行った後、 Model
のインスタンスを通じて session.run
や placeholder の指定を行っています。
tf.reset_default_graph()
model = Model()
saver = tf.train.Saver()
with tf.Session() as sess:
ckpt_path = tf.train.latest_checkpoint('ckpt/')
saver.restore(sess, ckpt_path)
res = sess.run(model.accuracy, feed_dict={
model.x: test_data,
model.y: test_label,
model.is_training: False
})
print('accuracy: ', res)
> INFO:tensorflow:Restoring parameters from ckpt/model.ckpt-5000
> accuracy: 0.9793
tf.train.import_meta_graph
を使ったロード
Model
のインスタンスを作成せずに、 tf.train.import_meta_graph
を使ってモデル構造を import しています。
placeholder や run の対象は、モデルのインスタンスがないので string で inputs:0
のように指定しています。
tf.reset_default_graph()
with tf.Session() as sess:
ckpt_path = tf.train.latest_checkpoint('ckpt/')
saver = tf.train.import_meta_graph(ckpt_path + '.meta')
saver.restore(sess, ckpt_path)
res = sess.run('accuracy:0', feed_dict={
'inputs:0': test_data,
'outputs:0': test_label,
'is_training:0': False
})
print('accuracy: ', res)
> INFO:tensorflow:Restoring parameters from ckpt/model.ckpt-5000
> accuracy: 0.9793
string で指定する以外にも、次のように対象となる Tensor を get_tensor_by_name
を使って取得することも出来ます。
is_training = tf.get_default_graph().get_tensor_by_name('is_training:0')
inputs = tf.get_default_graph().get_tensor_by_name('inputs:0')
outputs = tf.get_default_graph().get_tensor_by_name('outputs:0')
accuracy = tf.get_default_graph().get_tensor_by_name('accuracy:0')
res = sess.run(accuracy, feed_dict={
inputs: test_data,
outputs: test_label,
is_training: False
})
> INFO:tensorflow:Restoring parameters from ckpt/model.ckpt-5000
> accuracy: 0.9793
.meta
からのロードのメリット・デメリット
メリット
一番の理由は、**「モデル構造と学習したパラメータの対応を意識する必要がない」**ことでした。
(※これ関連で困っていたときに教えていただきました。)
学習済みモデルを使って推論を行う場合、当たり前ですが次の2つがちゃんと一致している必要があります。
- モデルのグラフ構造
- 学習したパラメータ
例えば、4層の DNN に5層の DNN で学習したパラメータを突っ込んでも失敗するわけです。
もちろんパラメータをアレコレして一部を使ったり置き換えたりで無理やり使ったりすることは出来ますが、今は「学習したモデルをそのまま推論に使いたい」というシーンでの話なので、これは意図した挙動ではありません。
モデルのインスタンスを使ったロードの場合、グラフ構造を class 内に定義してトレーニングを行い、合わせて出力される学習済みパラメータを保存します。
推論の際には、再度モデルクラスを読み込むことでグラフを作成し、先ほど保存したパラメータをモデルにロードすることで推論が行えるようになります。
# 再掲 > モデルクラスを使ったロード
# ... 略 ...
model = Model()
# ... 略 ...
saver.restore(sess, ckpt_path)
res = sess.run(model.accuracy, feed_dict={
model.x: test_data,
model.y: test_label,
model.is_training: False
})
この時、 Model
のインスタンス作成時に作られるグラフと対応した checkpoint が ckpt_path
にないとエラーになります。
すなわち、 Model
と checkpoint のパスは何かしらの方法で一緒に管理する必要が出てきます。
モデルが1つだったり、少数なら問題なさそうですね。
また、Production 環境ならちゃんとモデルとパラメータのミスマッチが起きない設計で管理すべきです。
しかし、手元で色々と実験をしていてモデル数が増えてきて、ちょっと動作を見比べたい、などのときには少しややこしかったりします。
例えば、ハイパーパラメータのチューニングでモデルを複数作っていたりすると辛くなってきます。
そこで、checkpoint 保存時に一緒に作られる .meta
ファイルを使ってモデルのグラフ構造も一緒にロードすることで、管理を一箇所にまとめることが出来ます。
# 再掲 > `tf.train.import_meta_graph` を使ったロード
with tf.Session() as sess:
ckpt_path = tf.train.latest_checkpoint('ckpt/')
saver = tf.train.import_meta_graph(ckpt_path + '.meta')
saver.restore(sess, ckpt_path)
res = sess.run('accuracy:0', feed_dict={
'inputs:0': test_data,
'outputs:0': test_label,
'is_training:0': False
})
tf.train.import_meta_graph
を使うことで、推論に Model
クラスを必要としていません。
代わりに、 string でノードの name
を指定しています。
これにより、モデルの中身が何であろうと( units が何個でも、 layer が何層であっても、全く違う構造であっても)、入出力さえ定義されていれば推論を行うことが出来るようになります。
もちろん、 .meta
ファイルと他の checkpoint のファイルが一緒に保存されている必要がありますが、これらは何もしなければ同じ場所に保存されるので、ミスマッチが起きることは基本ないでしょう。
これにより、モデルを変えても推論コードをいじる必要がなく、疎結合な推論コードを用意することが出来ます。
デメリット
いくつかあります。
string で Tensor を指定しないといけない
ミスでバグを生む可能性があります。見栄えもちょっと良くないかも。
入出力のインターフェースが決まっているなら、それ用の Struct などを用意しておくと良さそうです。
class ModelProtocol(NamedTuple):
inputs: str = 'input:0'
outputs: str = 'outputs:0'
is_training: str = 'is_training:0'
accuracy: str = 'accuracy:0'
グラフに何があるのかが見えづらくなる
tf.train.import_meta_graph
を使うと現在のグラフ上にノードが読み込まれます。
モデルクラスを定義した場合のように元となるソースがあるわけではないので、どのようなノードがあるのか分かりづらくなり、グラフ内での名前の衝突などが起きやすくなることが考えられます。
簡単にできる名前空間の対策としては、tf.train.import_meta_graph
の引数 import_scope=
を指定することで、ロードする変数のスコープを指定できます。
saver = tf.train.import_meta_graph(ckpt_path + '.meta', import_scope='my_model_01')
衝突は避けやすくなりますが、どのようなノードがあるのか見えづらいのは変わらないので、扱う際は気をつけましょう。
まとめ
TensorFlow でモデルをトレーニングし、保存したパラメータ (checkpoint) を推論時などにロードするための方法として、 tf.train.import_meta_graph
を紹介しました。
基本的にはモデルを定義したクラスを使う方法で問題ないのですが、多くのモデルとその checkpoint で実験をしている場合、インターフェース定義だけ合っていればロードして推論ができるのはメリットになることを紹介しました。
なにか間違いなどありましたらお気軽にご指摘いただけると幸いです