LoginSignup
7
9

More than 3 years have passed since last update.

pbファイルのグラフをTensorBoardで可視化

Last updated at Posted at 2019-12-01

環境

OS

Ubuntu 18.04

Docker

REPOSITORY: tensorflow/tensorflow
TAG: latest-gpu-py3-jupyter
IMAGE ID: 88178d65d12c

※上記Dockerは、Tensorflow2.0、TensorBoard 2.0.0

今回はeager executionモデル*をtf.saved_model.saveで保存したpbファイルで試した。
*: 例えばtutorialのモデル。

時間がない人は、pb2tensorboardのまとめへ。

試したこと

tensorflow githubのissue #8854 に記載されていた内容

jubjamie commented on 31 Mar 2017

import tensorflow as tf
from tensorflow.python.platform import gfile
with tf.Session() as sess:
    model_filename ='PATH_TO_PB.pb'
    with gfile.FastGFile(model_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        g_in = tf.import_graph_def(graph_def)
LOGDIR='YOUR_LOG_LOCATION'
train_writer = tf.summary.FileWriter(LOGDIR)
train_writer.add_graph(sess.graph)

上記のコードは、v1用のコードのため、v2では動かない。
(tf.Session()など対応していないものがあるため。)
そこで、tf_upgrade_v2コマンドを使って、v1のコードをv2のコードに変換して実行。
【参考】https://www.tensorflow.org/guide/upgrade
結果、graph_def.ParseFromString(f.read())で下記のエラーメッセージが出る。
google.protobuf.message.DecodeError: Error parsing message

brandondutra commented on 1 Apr 2017

上記、デコードエラーに対して、解決案のコメントがあった。

import tensorflow as tf
import sys
from tensorflow.python.platform import gfile

from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.util import compat

with tf.Session() as sess:
    model_filename ='saved_model.pb'
    with gfile.FastGFile(model_filename, 'rb') as f:

        data = compat.as_bytes(f.read())
        sm = saved_model_pb2.SavedModel()
        sm.ParseFromString(data)
        #print(sm)
        if 1 != len(sm.meta_graphs):
            print('More than one graph found. Not sure which to write')
            sys.exit(1)

        #graph_def = tf.GraphDef()
        #graph_def.ParseFromString(sm.meta_graphs[0])
        g_in = tf.import_graph_def(sm.meta_graphs[0].graph_def)
LOGDIR='YOUR_LOG_LOCATION'
train_writer = tf.summary.FileWriter(LOGDIR)
train_writer.add_graph(sess.graph)

上記のコードは、v1用のコードのため、v2では動かないので、同様に、tf_upgrade_v2を使って変換。
変換後のコードがこちら。

import tensorflow as tf
import sys
from tensorflow.python.platform import gfile

from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.util import compat

with tf.compat.v1.Session() as sess:
    model_filename ='saved_model.pb'
    with gfile.FastGFile(model_filename, 'rb') as f:

        data = compat.as_bytes(f.read())
        sm = saved_model_pb2.SavedModel()
        sm.ParseFromString(data)
        #print(sm)
        if 1 != len(sm.meta_graphs):
            print('More than one graph found. Not sure which to write')
            sys.exit(1)

        #graph_def = tf.GraphDef()
        #graph_def.ParseFromString(sm.meta_graphs[0])
        g_in = tf.import_graph_def(sm.meta_graphs[0].graph_def)
LOGDIR='YOUR_LOG_LOCATION'
train_writer = tf.compat.v1.summary.FileWriter(LOGDIR)
train_writer.add_graph(sess.graph)

デコードエラーは解消されたが、下記エラーメッセージによると、tf.summary.FileWriter(tf.compat.v1.summary.FileWriter)はeager executionに対応してない。

File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/summary/writer/writer.py", line 360, in __init__
    "tf.summary.FileWriter is not compatible with eager execution. "
RuntimeError: tf.summary.FileWriter is not compatible with eager execution. Use tf.contrib.summary instead.

代わりに、tf.contrib.summaryを使うようにあるが、下記のようにエラーメッセージが出る。

# import tensorflow as tfのとき
AttributeError: module 'tensorflow' has no attribute 'contrib'

# import tensorflow.compat.v1 as tfのとき
AttributeError: module 'tensorflow_core.compat.v1' has no attribute 'contrib'

ちなみにv2のtf.summaryFileWriterを持っていないので、tf.compat.v1.summary.FileWriterを使うことになる。

tf.disable_v2_behavior()を入れる。

tf.summary.FileWriterのエラーメッセージが出なくなる。
(上記エラーメッセージからすると正しい対処の方法ではなさそう。)

この段階で、TensorBoardに読み込ませるlogファイルが生成できる。
しかし、下図のように表示され、グラフが表示されない。

Screenshot from 2019-12-02 01-34-25.png

add_graphが、反映されていない。

FileWriterの引数にsess.graphを渡す。

train_writer.add_graph(sess.graph)でグラフを追加するのではなく、
train_writer = tf.compat.v1.summary.FileWriter(LOGDIR, sess.graph)
として、一括で行う。

あるいは、下記のようにwithを使っても解決できる。

with tf.summary.FileWriter(LOGDIR) as writer:
    writer.add_graph(sess.graph)

pb2tensorboardのまとめ

以上の試行をまとめて、サンプルコード(pb2tensorboard.py)を作成した。
【使い方】
※logの出力先ディレクトリは実行前に作成してはいけない。実行時に作成される。

$ python pb2tensorboard.py --in_pb <pbファイルのパス> --out_log_dir <TensorBoardにインポートするlogの出力先>

【コード】

import os
import sys
import argparse
import tensorflow.compat.v1 as tf
from tensorflow.python.platform import gfile
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.util import compat

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--in_pb", help="input pb file path.",
                        type=str, default=None)
    parser.add_argument(
        "--out_log_dir", help='output log directory.', type=str, default=os.path.join('.', 'log'))
    args = parser.parse_args()
    tf.disable_v2_behavior()
    with tf.Session() as sess:
        with gfile.GFile(args.in_pb, 'rb') as f:
            data = compat.as_bytes(f.read())
            sm = saved_model_pb2.SavedModel()
            sm.ParseFromString(data)
            if 1 != len(sm.meta_graphs):
                print('More than one graph found. Not sure which to write')
                sys.exit(1)
            g_in = tf.import_graph_def(sm.meta_graphs[0].graph_def)
    train_writer = tf.summary.FileWriter(args.out_log_dir, sess.graph)

if __name__ == '__main__':
    main()

つぶやき

  • 実は、Tensorflow公式githubにはimport_pb_to_tensorboard.pyが用意されているが、うまく動かない(19.12.2)。
  • もっとスマートな方法があれば、教えてください。
7
9
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
9