TensorFlow のモデル保存アレコレ 1

  • 13
    いいね
  • 0
    コメント

TensorFlow のモデル保存ってほとんど全てのユーザーが通る道なのに、結構細かい仕組みまで知っておかないとよくわからない部分が多くて意外とややこしいと思ったので、躓きやすい部分(もしくは自分が実際に躓いた部分)を中心にまとめてみようかと思いました。

Variable の値を save/restore するときのアレコレ

tf.train.Saver のコンストラクタについて

以下のコードを実行するとエラーが発生します。

import tensorflow as tf

# Save variables
with tf.Graph().as_default() as g1:
    v0 = tf.Variable(0, name="variable_0")
    saver1 = tf.train.Saver()  # variable_0 を save/restore するオペレーションがグラフ g1 追加される
    v1 = tf.Variable(1, name="variable_1")
    init_op = tf.global_variables_initializer()
    with tf.Session() as sess1:
        sess1.run(init_op)
        saver1.save(sess1, "model/sample1/export")  # variable_1 は保存されない

# Restore variables
with tf.Graph().as_default() as g2:
    v0 = tf.Variable(10, name="variable_0")
    v1 = tf.Variable(100, name="variable_1")
    saver2 = tf.train.Saver()  # variable_0, variable_1 を save/restore するオペレーションがグラフ g2 に追加される
    init_op = tf.global_variables_initializer()
    with tf.Session() as sess2:
        sess2.run(init_op)
        saver2.restore(sess2, "model/sample1/export")  # variable_0, variable_1 を両方とも restore しようとする
W tensorflow/core/framework/op_kernel.cc:975] Not found: Key variable_1 not found in checkpoint

解説

tf.Saver はコンストラクタが実行された時に、 default graph に対してその時点で存在する Variable を save したり restore したりするオペレーションを追加します。

なので、 saver1 は variable_0 だけを保存したのに対して saver2 は variable_0 と variable_1 を restore しようとしています。

variable_1 を restore するオペレーションを実行しようとした時に、ファイル側にそんなもん保存されていないぞ!というエラーが発生したわけです。

教訓

tf.Saver は保存したい全ての Variable を定義した後に定義しましょう。

保存されたファイルから一部のデータだけを restore

先程はファイルに存在しないデータを restore しようとしてエラーが起こりましたが、ファイルに余分なデータが存在する場合は無視してくれます。

import tensorflow as tf

# Save variables
with tf.Graph().as_default() as g1:
    v0 = tf.Variable(0, name="variable_0")
    v1 = tf.Variable(1, name="variable_1")
    saver1 = tf.train.Saver()
    init_op = tf.global_variables_initializer()
    with tf.Session() as sess1:
        sess1.run(init_op)
        saver1.save(sess1, "model/sample2/export")

# Restore variables
with tf.Graph().as_default() as g2:
    v3 = tf.Variable(10, name="variable_0")
    saver2 = tf.train.Saver()
    with tf.Session() as sess2:
        saver2.restore(sess2, "model/sample2/export")
        print sess2.run(v3)
0

解説

保存した variable_0 の値が正しく restore されていますね。
variable_1 の値もファイルに保存されていますが、 restore するときグラフに同じ名前の Variable が存在しないので無視されています。

教訓

取り出したい Variable の名前だけを把握しておけば、ファイルに余分なデータが含まれていても必要な値を部分的に取り出すことができます。

tf.train.Saver の珍プレー

最後に、最高に意味不明な珍プレーを紹介しましょう。

# -*- coding: utf-8 -*-

import tensorflow as tf

# グラフ g1 は作ったけど使わない
with tf.Graph().as_default() as g1:
    v0 = tf.Variable(0, name="variable_0")
    saver1 = tf.train.Saver()

with tf.Graph().as_default() as g2:
    v1 = tf.Variable(10, name="variable_1")
    saver2 = tf.train.Saver()
    init_op = tf.global_variables_initializer()
    with tf.Session() as sess2:
        sess2.run(init_op)
        saver1.save(sess2, "model/sample3/export")  # saver1 を使っているのは typo じゃない

with tf.Graph().as_default() as g3:
    v2 = tf.Variable(100, name="variable_1")
    saver3 = tf.train.Saver()
    with tf.Session() as sess3:
        saver1.restore(sess3, "model/sample3/export")  # saver1 を使っているのは typo じゃない
        print sess3.run(v2)
10

解説

お分かり頂けただろうか。
グラフ g1 のために作成した saver1 で、グラフ g2 やグラフ g3 に対して Variable を save したり restore できているのです。

なぜこんな挙動をするのかはソースコードを読まないと確信を持っては言えませんが、 tf.train.Saver 次のような仕様なのでしょう。

  • saver1 はグラフ g1 の save/restore をするオペレーションへの直接の参照を持っているわけではない
  • saver1.save を実行する時にセッションを引数で渡すので、渡したセッションに紐付いたグラフから save/restore のオペレーションを探して全て実行している

つまり、 saver2 を作ったときのコンストラクタでグラフ g2 に save/restore のオペレーションが追加され、 saver1.savesess2 を渡して実行した時に sess2.graph に含まれる全ての save/restore のオペレーションが検索されて実行されたと思われます。

言ってて意味わからないですね。
あなたに届け、この想い。

次回予告

Variable の保存については(伝わったかはともかく)大体言いたいことを言えました。
今度はグラフの保存について書こうと思っています。

たとえば学習の時には dropout の placeholder を入れたけど、モデルを保存するときには除きたいとか。

あとは、 Google Cloud Machine Learning にデプロイする時は入出力のオペレーションを collections に入れて楽に取り出せるようにしているので、最近は自分もそれに倣って保存したモデルを読み込んだ後に呼び出す必要があるオペレーションとかは collections に突っ込むようにしているとか。

そんな話です。

この投稿は TensorFlow Advent Calendar 201623日目の記事です。