4
Help us understand the problem. What are the problem?

More than 5 years have passed since last update.

posted at

TensorFlow の tf.train.Saver で GCS 上に直接 Graph を書き出す時の注意点

Exporting and Importing a Metagraph を元にした小ネタです。

tf.train.Saver

TensorFlow の MNIST のチュートリアルを読んでも登場しない tf.train.Saver ですが、学習したモデルを実際に予測に使ってみようと思うと非常に重要なクラスです。
つまり学習に用いた Graph とその変数 Variable の重みなどの学習結果は、学習(および検証/テスト)を行なった Session 内に存在しているので、後で予測のためにそのパラメータを使いたいと思ったら、Graph の情報をファイルに保存しておく必要があります。
また予測を行うプロセスでは別の Session に保存しておいた Variable の内容をロードする必要があります。

Exporting and Importing a Metagraph を読むと、tf.train.Saver というクラスが save(), restore() メソッドを実装していて、Variable の内容の保存、復帰ができることがわかります。
さらに Graph の形状まで完全に再現したい時は export_meta_graph() および import_meta_graph() というメソッドを用いて MetaGraph の保存、復帰を行なうことができます。ただわたしが試した限りだと tf.train.Saver.save() だけでも MetaGraph は保存されているようです。tf.train.Saver.save のドキュメントによると write_meta_graph という引数があり、デフォルトが True なので同時に Meta Graph も保存されるみたいです。
従って読み込み時に tf.train.Saver.restore() が Variable の復元のみ、tf.train.Saver.import_meta_graph() が構築済みの Graph そのものを復元する、というように使いわければいいようです。

また tf.train.Saver は学習中のチェックポイント毎に途中経過を保存するという用途のために作られたみたいで、global_step という引数を渡すことで学習の step 毎にファイル名を変えて保存するといった機能もあるようです。

tf.train.Saver についてももうすこし深掘りしたいところですが、今回はこのくらいに。

TensorFlow のファイル操作と GCS(Google Cloud Storage)

0.11 においては tf.train.Saver に限らず TensorFlow でファイル操作のようにみえるところでは、GCS(Google Cloud Storage) を利用できるようになっているみたいです。

従って tf.train.Saver でも、ディレクトリパスとして gs://my-bucket/checkpoints のような URL を渡すことで、直接 GCS 上にログを保存できました。

すこし TensorFlow のソースコードをみてみると tf.python.platform.gfile というモジュールにファイル操作に関するメソッドのインタフェースが揃えられていて、この先ではファイルパス(URL)をみて適切な方法で具体的な実装が行なわれるようにディスパッチしてくれる、となっているみたいです。このあたりの事情というか方針がわかる開発者向けのドキュメントはどこかにあるといいのですが。

とにかく、TensorFlow を使ってコーディングしている際には、TensorFlow から入出力するファイルに関する操作は gfile モジュールを介するようにしておくと自動的に GCS 対応ができていいかもしれません。

具体的には Google がベータ版として公開している Cloud Machine Learning (以下 Cloud ML)上で TensorFlow で構築したモデルを学習させようと思うと、モデルの情報は GCS 上に保存せざるを得ないので必須になります。

なお、TensorFlow のソースコードを探してみましたが、AWS の S3 に対応するというコードは master ブランチにはみあたらないようです。

tf.train.Saver で GCS 上に直接保存する時の注意点

ところが、tf.train.Saver.save()0.11rc2 の実装では、保存前に親ディレクトリが存在するかどうかをチェックしています。

出力先が gs://.. で始まる時はこれは GCS 上のオブジェクトを指すので、os.makedirs() などでローカルのファイルシステムにディレクトリを作成しても無意味ですので、先ほど書いた通り gfile モジュールを用いて以下のようにしておく必要があります。

import os.path
from tensorflow.python.platform import gfile

saver = tf.train.Saver()

gfile.MakeDirs(log_dir)
saver.save(session, os.path.join(log_dir, "model"))

ところで、出力先がローカルファイルの時はこれは妥当なのですが、GCS ではバケット内のファイルツリーのようにみえているのはみかけだけの話で、いきなり gs://my-bucket/path/to/deep/subdir/myfile のようなネストしたディレクトリの下(のようにみえるパス)のファイル(GCS では正確にはオブジェクト) を書くことは可能なので、このチェックは冗長な気もします。
パスが GCS の URL の場合はこの gfile.IdDirectory)() のチェックは省くようにしてもいいのではないでしょうか。

まとめ

  • tf.train.Saver は Graph の情報を保存/復帰するクラス
  • TensorFlow のファイル操作は GCS も透過的にあつかえる
  • ファイル操作は gfile モジュールを使おう
  • tf.train.Saver.save() の前に gfile.MakeDirs() で作成しよう
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Sign upLogin
4
Help us understand the problem. What are the problem?