やりたいこと
色んな画像に対してObject Detection APIで
sess.run()したかったが、一枚5~8秒ほどかかった。
# イメージ
for i in range(10):
tf.Session(graph=hoge) as sess:
sess.run(hoge_op, feed_dict={})
実際はクラス化したり、sessionがネストしてたり、
外部から呼び出されたりでもっと複雑だったが、
まぁ雰囲気はこんな感じだった。
対処
要はこう。
# イメージ
tf.Session(graph=hoge) as sess:
for i in range(10):
sess.run(hoge_op, feed_dict={})
何故か。グラフとセッションの関係
この2つの概念を知っておけば、自分のような初心者でも何となくTensorflowのコードが理解出来るようになった。
グラフは回路のようなもので、セッションでの実行で初めて通電するイメージで認識している。
- グラフ:計算グラフ
情報系の人間には見覚えのあるグラフ。
TensorFlowで書く時は↓みたいな感じ。(あ、コードは上のグラフとは関係ないです。)
電源の無い回路のようなもので、セッションによって通電されないと各変数(tf.placeholder)に値が入らない。
add_graph = tf.Graph()
with add_graph.as_default():
a = tf.placeholder(tf.int32, shape=[], name="a")
b = tf.placeholder(tf.int32, shape=[], name="b")
add_op = tf.add(a, b, name="add_op")
- セッション:
グラフを実行するために存在する。
コードによっては省略されているが、tf.Session(graph=hoge)のように、実行対象のグラフは必ず指定されている。
tf.Session()でグラフを指定したsessを作成出来て、sess.run()でその中の一部とか全部を出力出来る。
with tf.Session(graph=add_graph) as sess:
ret = sess.run(add_op, feed_dict={a:1,b:1})
print ret
例えばadd_opのような、グラフ内の変数名を指定すれば、feed_dictを入力として起動したグラフにおける、該当の変数(add_op)を返り値として得られる。
複数欲しいなら[a, b, add_op]でOK。
- 実装の流れ:
- グラフを作成
- セッションを作成→sess.run()
結果として、何故遅かったのか
最初の例だと、毎回セッションの作成を行ってしまうから。
セッションは作成時にメモリ確保とかを行うので、作成にはかなり時間がかかる。
参考サイト
グラフとセッションの部分のコード参考:
http://docs.fabo.io/tensorflow/building_graph/tensorflow_graph_part2.html
グラフとセッションについての説明:
https://arakan-pgm-ai.hatenablog.com/entry/2017/05/04/173031
感想
Chainerとかは直感的だけど、TensorFlowは結構独特で、
ごくごく軽く触る程度の身としては困っている。
勉強(機械学習)のための勉強(Tensorflowへの慣れ)、
になってしまったので、
やはりChainerみたいなdefine by run系がとっつきやすいと思う。