LoginSignup
13
10

More than 5 years have passed since last update.

tf.Graphの使い所

Posted at

問題

今、tensorflowを使ったウェブアプリを作っています。複数種類のモデルで識別の結果を出したい。それらのモデルは個別に学習させて、学習結果を別々に保存し、サーバ起動時に、それを読みこんで動かそうと思いました。

ざっくりとコードの断片を書くと

ai/util.py
class MyModel():
    # モデルの管理をしたりするクラス
ai/ai1.py
# モデルを作成
model = MyModel(モデルの詳細)
if __name__ == '__main__':
    # 学習段階なら学習し、save/ai1.ckptに学習結果を保存する

(ai2, ai3も同様に)

app.py
from ai.ai1 imoort model as model1
from ai.ai2 imoort model as model2
from ai.ai3 imoort model as model3

# 学習結果を読み込む
model1.restore()
model2.restore()
model3.restore()

のように、別々に学習させて、サーバ起動時に読み込む。
しかし、このようにすると、ai2の学習結果を読み込むmodel2.restoreのところでエラーが出る。model2には◯◯って変数がないよって言われます。もちろん、一つだけなら問題ない。

解決

何が原因かというと、すべてのモデルを同じGraphに書き込んでしまっていることがいけなかった。tensorflowはデフォルトで1つのGraphが用意されており、それに書き込んでいくということになっている。importするごとにそこにモデルが書き込まれるわけだが、当然、3つのモデルが混ざっていて、個別にする学習のときのGraphと異なる構成になってしまう。

これを解決するには、別のGraphで読み込めばいい。なので、例えば、app.pyの該当部分を

app.py
with tf.Graph().as_default():
    from ai.ai1 imoort model as model1
    model1.restore()
with tf.Graph().as_default():
    from ai.ai2 imoort model as model2
    model2.restore()
with tf.Graph().as_default():
    from ai.ai3 imoort model as model3
    model3.restore()

のようにwith tf.Graph().as_default():で囲ってしまえば良い。

13
10
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
13
10