問題
今、tensorflowを使ったウェブアプリを作っています。複数種類のモデルで識別の結果を出したい。それらのモデルは個別に学習させて、学習結果を別々に保存し、サーバ起動時に、それを読みこんで動かそうと思いました。
ざっくりとコードの断片を書くと
ai/util.py
class MyModel():
# モデルの管理をしたりするクラス
ai/ai1.py
# モデルを作成
model = MyModel(モデルの詳細)
if __name__ == '__main__':
# 学習段階なら学習し、save/ai1.ckptに学習結果を保存する
(ai2, ai3も同様に)
app.py
from ai.ai1 imoort model as model1
from ai.ai2 imoort model as model2
from ai.ai3 imoort model as model3
# 学習結果を読み込む
model1.restore()
model2.restore()
model3.restore()
のように、別々に学習させて、サーバ起動時に読み込む。
しかし、このようにすると、ai2の学習結果を読み込むmodel2.restoreのところでエラーが出る。model2には◯◯って変数がないよって言われます。もちろん、一つだけなら問題ない。
解決
何が原因かというと、すべてのモデルを同じGraphに書き込んでしまっていることがいけなかった。tensorflowはデフォルトで1つのGraphが用意されており、それに書き込んでいくということになっている。importするごとにそこにモデルが書き込まれるわけだが、当然、3つのモデルが混ざっていて、個別にする学習のときのGraphと異なる構成になってしまう。
これを解決するには、別のGraphで読み込めばいい。なので、例えば、app.pyの該当部分を
app.py
with tf.Graph().as_default():
from ai.ai1 imoort model as model1
model1.restore()
with tf.Graph().as_default():
from ai.ai2 imoort model as model2
model2.restore()
with tf.Graph().as_default():
from ai.ai3 imoort model as model3
model3.restore()
のようにwith tf.Graph().as_default():
で囲ってしまえば良い。