TensorFlow のモデル保存ってほとんど全てのユーザーが通る道なのに、結構細かい仕組みまで知っておかないとよくわからない部分が多くて意外とややこしいと思ったので、躓きやすい部分(もしくは自分が実際に躓いた部分)を中心にまとめてみようかと思いました。
Variable の値を save/restore するときのアレコレ
tf.train.Saver のコンストラクタについて
以下のコードを実行するとエラーが発生します。
import tensorflow as tf
# Save variables
with tf.Graph().as_default() as g1:
v0 = tf.Variable(0, name="variable_0")
saver1 = tf.train.Saver() # variable_0 を save/restore するオペレーションがグラフ g1 追加される
v1 = tf.Variable(1, name="variable_1")
init_op = tf.global_variables_initializer()
with tf.Session() as sess1:
sess1.run(init_op)
saver1.save(sess1, "model/sample1/export") # variable_1 は保存されない
# Restore variables
with tf.Graph().as_default() as g2:
v0 = tf.Variable(10, name="variable_0")
v1 = tf.Variable(100, name="variable_1")
saver2 = tf.train.Saver() # variable_0, variable_1 を save/restore するオペレーションがグラフ g2 に追加される
init_op = tf.global_variables_initializer()
with tf.Session() as sess2:
sess2.run(init_op)
saver2.restore(sess2, "model/sample1/export") # variable_0, variable_1 を両方とも restore しようとする
W tensorflow/core/framework/op_kernel.cc:975] Not found: Key variable_1 not found in checkpoint
解説
tf.Saver
はコンストラクタが実行された時に、 default graph に対してその時点で存在する Variable
を save したり restore したりするオペレーションを追加します。
なので、 saver1
は variable_0 だけを保存したのに対して saver2
は variable_0 と variable_1 を restore しようとしています。
variable_1 を restore するオペレーションを実行しようとした時に、ファイル側にそんなもん保存されていないぞ!というエラーが発生したわけです。
教訓
tf.Saver
は保存したい全ての Variable
を定義した後に定義しましょう。
保存されたファイルから一部のデータだけを restore
先程はファイルに存在しないデータを restore しようとしてエラーが起こりましたが、ファイルに余分なデータが存在する場合は無視してくれます。
import tensorflow as tf
# Save variables
with tf.Graph().as_default() as g1:
v0 = tf.Variable(0, name="variable_0")
v1 = tf.Variable(1, name="variable_1")
saver1 = tf.train.Saver()
init_op = tf.global_variables_initializer()
with tf.Session() as sess1:
sess1.run(init_op)
saver1.save(sess1, "model/sample2/export")
# Restore variables
with tf.Graph().as_default() as g2:
v3 = tf.Variable(10, name="variable_0")
saver2 = tf.train.Saver()
with tf.Session() as sess2:
saver2.restore(sess2, "model/sample2/export")
print sess2.run(v3)
0
解説
保存した variable_0 の値が正しく restore されていますね。
variable_1 の値もファイルに保存されていますが、 restore するときグラフに同じ名前の Variable
が存在しないので無視されています。
教訓
取り出したい Variable
の名前だけを把握しておけば、ファイルに余分なデータが含まれていても必要な値を部分的に取り出すことができます。
tf.train.Saver の珍プレー
最後に、最高に意味不明な珍プレーを紹介しましょう。
# -*- coding: utf-8 -*-
import tensorflow as tf
# グラフ g1 は作ったけど使わない
with tf.Graph().as_default() as g1:
v0 = tf.Variable(0, name="variable_0")
saver1 = tf.train.Saver()
with tf.Graph().as_default() as g2:
v1 = tf.Variable(10, name="variable_1")
saver2 = tf.train.Saver()
init_op = tf.global_variables_initializer()
with tf.Session() as sess2:
sess2.run(init_op)
saver1.save(sess2, "model/sample3/export") # saver1 を使っているのは typo じゃない
with tf.Graph().as_default() as g3:
v2 = tf.Variable(100, name="variable_1")
saver3 = tf.train.Saver()
with tf.Session() as sess3:
saver1.restore(sess3, "model/sample3/export") # saver1 を使っているのは typo じゃない
print sess3.run(v2)
10
解説
お分かり頂けただろうか。
グラフ g1
のために作成した saver1
で、グラフ g2
やグラフ g3
に対して Variable
を save したり restore できているのです。
なぜこんな挙動をするのかはソースコードを読まないと確信を持っては言えませんが、 tf.train.Saver
次のような仕様なのでしょう。
-
saver1
はグラフg1
の save/restore をするオペレーションへの直接の参照を持っているわけではない -
saver1.save
を実行する時にセッションを引数で渡すので、渡したセッションに紐付いたグラフから save/restore のオペレーションを探して全て実行している
つまり、 saver2
を作ったときのコンストラクタでグラフ g2
に save/restore のオペレーションが追加され、 saver1.save
に sess2
を渡して実行した時に sess2.graph
に含まれる全ての save/restore のオペレーションが検索されて実行されたと思われます。
言ってて意味わからないですね。
あなたに届け、この想い。
次回予告
Variable
の保存については(伝わったかはともかく)大体言いたいことを言えました。
今度はグラフの保存について書こうと思っています。
たとえば学習の時には dropout の placeholder
を入れたけど、モデルを保存するときには除きたいとか。
あとは、 Google Cloud Machine Learning にデプロイする時は入出力のオペレーションを collections に入れて楽に取り出せるようにしているので、最近は自分もそれに倣って保存したモデルを読み込んだ後に呼び出す必要があるオペレーションとかは collections に突っ込むようにしているとか。
そんな話です。