概要
FlaskとKerasを使ってアップロードした画像の識別をするプログラムを動かすと、Cannot interpret feed_dict key as Tensorが出てエラーが出た。
ここにも記載があるが、これはマルチスレッドで動作する場合に、読み込んだモデルが別のスレッドには共有されないため起きる。
Flask + Tensorflow + Keras で モデルからのpredictが動作しない問題
load_modelを使って読み込んだモデルは、tf.get_default_graph()を使ってスレッド間で共有できるようにする必要がある。
例
import tensorflow as tf
from keras.models import load_model
@app.route('/', methods=['GET', 'POST'])
def upload_file():
model = load_model('./sample_model.h5')
X = data
result = model.predict(X)
X = dataの部分は、使うモデルに合わせて必要な型のデータを準備する。
修正前は、画像がアップロードされたタイミングでモデルを読み込んでいた。これだと1回目のアップロードは問題ないが、2回目以降にCannot interpret feed_dict key as Tensorが発生した。
import tensorflow as tf
from keras.models import load_model
model = load_model('./sample_model.h5')
graph = tf.get_default_graph()
@app.route('/', methods=['GET', 'POST'])
def upload_file():
global graph
with graph.as_default():
X = data
result = model.predict(X)
修正内容は、モデルの読み込みをプログラム起動時に実施し、そのグラフをgraphに保存しておく。
ファイルアップロード時にはgraph.as_default()でプログラム起動時に保存したグラフを読み込んで処理する。
こうするとアップロード2回目以降も、Cannot interpret feed_dict key as Tensorは発生しなくなった。
メモ
修正前のプログラムでは、ファイルアップロード時にモデルの読み込みを実施しているので、
問題ないような気がするが、2回目以降のファイルアップロードでエラーが出た。
修正後の方が、モデルのロードは1回で済むのでスッキリするが、なぜ修正前のプログラムでは
エラーとなったのかが、よくわかっていない。
Comments
Let's comment your feelings that are more than good