環境
OS
Ubuntu 18.04
Docker
REPOSITORY: tensorflow/tensorflow
TAG: latest-gpu-py3-jupyter
IMAGE ID: 88178d65d12c
※上記Dockerは、Tensorflow2.0、TensorBoard 2.0.0
今回はeager executionモデル*をtf.saved_model.save
で保存したpbファイルで試した。
*: 例えばtutorialのモデル。
時間がない人は、pb2tensorboardのまとめへ。
試したこと
tensorflow githubのissue #8854 に記載されていた内容
jubjamie commented on 31 Mar 2017
import tensorflow as tf
from tensorflow.python.platform import gfile
with tf.Session() as sess:
model_filename ='PATH_TO_PB.pb'
with gfile.FastGFile(model_filename, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
g_in = tf.import_graph_def(graph_def)
LOGDIR='YOUR_LOG_LOCATION'
train_writer = tf.summary.FileWriter(LOGDIR)
train_writer.add_graph(sess.graph)
上記のコードは、v1用のコードのため、v2では動かない。
(tf.Session()など対応していないものがあるため。)
そこで、tf_upgrade_v2
コマンドを使って、v1のコードをv2のコードに変換して実行。
【参考】https://www.tensorflow.org/guide/upgrade
結果、graph_def.ParseFromString(f.read())
で下記のエラーメッセージが出る。
google.protobuf.message.DecodeError: Error parsing message
brandondutra commented on 1 Apr 2017
上記、デコードエラーに対して、解決案のコメントがあった。
import tensorflow as tf
import sys
from tensorflow.python.platform import gfile
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.util import compat
with tf.Session() as sess:
model_filename ='saved_model.pb'
with gfile.FastGFile(model_filename, 'rb') as f:
data = compat.as_bytes(f.read())
sm = saved_model_pb2.SavedModel()
sm.ParseFromString(data)
#print(sm)
if 1 != len(sm.meta_graphs):
print('More than one graph found. Not sure which to write')
sys.exit(1)
#graph_def = tf.GraphDef()
#graph_def.ParseFromString(sm.meta_graphs[0])
g_in = tf.import_graph_def(sm.meta_graphs[0].graph_def)
LOGDIR='YOUR_LOG_LOCATION'
train_writer = tf.summary.FileWriter(LOGDIR)
train_writer.add_graph(sess.graph)
上記のコードは、v1用のコードのため、v2では動かないので、同様に、tf_upgrade_v2
を使って変換。
変換後のコードがこちら。
import tensorflow as tf
import sys
from tensorflow.python.platform import gfile
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.util import compat
with tf.compat.v1.Session() as sess:
model_filename ='saved_model.pb'
with gfile.FastGFile(model_filename, 'rb') as f:
data = compat.as_bytes(f.read())
sm = saved_model_pb2.SavedModel()
sm.ParseFromString(data)
#print(sm)
if 1 != len(sm.meta_graphs):
print('More than one graph found. Not sure which to write')
sys.exit(1)
#graph_def = tf.GraphDef()
#graph_def.ParseFromString(sm.meta_graphs[0])
g_in = tf.import_graph_def(sm.meta_graphs[0].graph_def)
LOGDIR='YOUR_LOG_LOCATION'
train_writer = tf.compat.v1.summary.FileWriter(LOGDIR)
train_writer.add_graph(sess.graph)
デコードエラーは解消されたが、下記エラーメッセージによると、tf.summary.FileWriter
(tf.compat.v1.summary.FileWriter
)はeager executionに対応してない。
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/summary/writer/writer.py", line 360, in __init__
"tf.summary.FileWriter is not compatible with eager execution. "
RuntimeError: tf.summary.FileWriter is not compatible with eager execution. Use tf.contrib.summary instead.
代わりに、tf.contrib.summary
を使うようにあるが、下記のようにエラーメッセージが出る。
# import tensorflow as tfのとき
AttributeError: module 'tensorflow' has no attribute 'contrib'
# import tensorflow.compat.v1 as tfのとき
AttributeError: module 'tensorflow_core.compat.v1' has no attribute 'contrib'
ちなみにv2のtf.summary
はFileWriter
を持っていないので、tf.compat.v1.summary.FileWriter
を使うことになる。
tf.disable_v2_behavior()
を入れる。
tf.summary.FileWriterのエラーメッセージが出なくなる。
(上記エラーメッセージからすると正しい対処の方法ではなさそう。)
この段階で、TensorBoardに読み込ませるlogファイルが生成できる。
しかし、下図のように表示され、グラフが表示されない。
add_graphが、反映されていない。
FileWriter
の引数にsess.graph
を渡す。
train_writer.add_graph(sess.graph)
でグラフを追加するのではなく、
train_writer = tf.compat.v1.summary.FileWriter(LOGDIR, sess.graph)
として、一括で行う。
あるいは、下記のようにwithを使っても解決できる。
with tf.summary.FileWriter(LOGDIR) as writer:
writer.add_graph(sess.graph)
pb2tensorboardのまとめ
以上の試行をまとめて、サンプルコード(pb2tensorboard.py)を作成した。
【使い方】
※logの出力先ディレクトリは実行前に作成してはいけない。実行時に作成される。
$ python pb2tensorboard.py --in_pb <pbファイルのパス> --out_log_dir <TensorBoardにインポートするlogの出力先>
【コード】
import os
import sys
import argparse
import tensorflow.compat.v1 as tf
from tensorflow.python.platform import gfile
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.util import compat
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--in_pb", help="input pb file path.",
type=str, default=None)
parser.add_argument(
"--out_log_dir", help='output log directory.', type=str, default=os.path.join('.', 'log'))
args = parser.parse_args()
tf.disable_v2_behavior()
with tf.Session() as sess:
with gfile.GFile(args.in_pb, 'rb') as f:
data = compat.as_bytes(f.read())
sm = saved_model_pb2.SavedModel()
sm.ParseFromString(data)
if 1 != len(sm.meta_graphs):
print('More than one graph found. Not sure which to write')
sys.exit(1)
g_in = tf.import_graph_def(sm.meta_graphs[0].graph_def)
train_writer = tf.summary.FileWriter(args.out_log_dir, sess.graph)
if __name__ == '__main__':
main()
つぶやき
- 実は、Tensorflow公式githubには
import_pb_to_tensorboard.py
が用意されているが、うまく動かない(19.12.2)。 - もっとスマートな方法があれば、教えてください。