Help us understand the problem. What is going on with this article?

【Flask+Keras】サーバーで複数モデルを高速で推論させる方法

結論

keras==2.2.4
tensorflow==1.14.0
numpy==1.16.4

テストコード

from flask import Flask
import time

import numpy as np
import tensorflow as tf
from keras.models import load_model
from keras.preprocessing.image import img_to_array, load_img

# Flask application object; the route defined below is registered on it.
app = Flask(__name__)

# --- Model 1 ---------------------------------------------------------------
# Loaded once at module import time, so individual requests do not pay the
# model-loading cost.
model_path1 = "mnist.h5"
model1 = load_model(model_path1)
# Class names; predict() output position i maps to label1[i] (see argmax use in
# model1_predict).
label1 = ["l0", "l1", "l2", "l3", "l4", "l5", "l6", "l7", "l8", "l9"]
model1._make_predict_function()  # <very important> build the predict function up front to speed up predict (private Keras 2.x API; presumably avoids building it lazily on the first request)
# Capture the graph the model was loaded into, so request handlers can re-enter
# it with `graph1.as_default()`.
graph1 = tf.get_default_graph()


# --- Model 2 ---------------------------------------------------------------
# NOTE(review): this demo loads the same file as model 1 ("mnist.h5").
model_path2 = "mnist.h5"
model2 = load_model(model_path2)
label2 = ["l0", "l1", "l2", "l3", "l4", "l5", "l6", "l7", "l8", "l9"]
model2._make_predict_function()
# NOTE(review): tf.get_default_graph() is called again in the same thread with no
# intervening graph context, so graph2 is presumably the same tf.Graph object as
# graph1 (both models live in one graph) — confirm if the models are meant to be
# isolated.
graph2 = tf.get_default_graph()

def model1_predict(img_path):
    """Classify one image with ``model1`` and print the top score and label.

    Args:
        img_path: Path to an image file. It is loaded as grayscale, resized to
            28x28 and scaled by 1/255 before prediction.

    Returns:
        A ``(score, label)`` tuple: the highest predicted value and the entry of
        ``label1`` at the arg-max position. (Previously ``None`` was returned;
        returning the result lets callers use it instead of parsing stdout.)
    """
    # load_img returns an image object; convert it to an array exactly once.
    # The original applied img_to_array a second time to an array, which was
    # redundant.
    img = img_to_array(load_img(img_path, target_size=(28, 28), grayscale=True))
    img_nad = img / 255
    # Add a leading batch axis, since predict() takes a batch of inputs.
    img_nad = img_nad[None, ...]
    # In TF 1.x the default graph is per-thread; Flask may serve requests on a
    # worker thread, so re-enter the graph captured when the model was loaded.
    # (No `global` statement is needed just to read graph1.)
    with graph1.as_default():
        pred = model1.predict(img_nad, batch_size=1, verbose=0)
    score = np.max(pred)
    pred_label = label1[np.argmax(pred[0])]
    print("スコア:", score, "ラベル:", pred_label)
    return score, pred_label

def model2_predict(img_path):
    """Classify one image with ``model2`` and print the top score and label.

    Args:
        img_path: Path to an image file. It is loaded as grayscale, resized to
            28x28 and scaled by 1/255 before prediction.

    Returns:
        A ``(score, label)`` tuple: the highest predicted value and the entry of
        ``label2`` at the arg-max position. (Previously ``None`` was returned;
        returning the result lets callers use it instead of parsing stdout.)
    """
    # load_img returns an image object; convert it to an array exactly once.
    # The original applied img_to_array a second time to an array, which was
    # redundant.
    img = img_to_array(load_img(img_path, target_size=(28, 28), grayscale=True))
    img_nad = img / 255
    # Add a leading batch axis, since predict() takes a batch of inputs.
    img_nad = img_nad[None, ...]
    # In TF 1.x the default graph is per-thread; Flask may serve requests on a
    # worker thread, so re-enter the graph captured when the model was loaded.
    # (No `global` statement is needed just to read graph2.)
    with graph2.as_default():
        pred = model2.predict(img_nad, batch_size=1, verbose=0)
    score = np.max(pred)
    pred_label = label2[np.argmax(pred[0])]
    print("スコア:", score, "ラベル:", pred_label)
    return score, pred_label

@app.route("/", methods=['GET', 'POST'])
def webapp():
    """Time one prediction from each model and return the timings as HTML.

    Both models are run on the same test image ``mnist_test.jpg``. Each
    elapsed time is printed to stdout and returned, rounded to 3 decimals,
    in a small HTML fragment.
    """
    # perf_counter is a monotonic, high-resolution clock meant for measuring
    # durations; time.time() follows the wall clock and can jump if the system
    # time is adjusted mid-measurement.
    start1 = time.perf_counter()
    model1_predict("mnist_test.jpg")
    elapsed1 = time.perf_counter() - start1
    print("処理時間<model1>: ", elapsed1, "秒")

    start2 = time.perf_counter()
    model2_predict("mnist_test.jpg")
    elapsed2 = time.perf_counter() - start2
    print("処理時間<model2>: ", elapsed2, "秒")

    output = "<p>model1:{}秒</p><br><p>model2:{}秒</p>".format(
        round(elapsed1, 3), round(elapsed2, 3)
    )
    return output

if __name__ == "__main__":
    # Launch Flask's built-in server on port 5000 with debug mode turned off.
    app.run(debug=False, port=5000)

重要な部分

# Excerpt of the essential pattern: load once at module level, build the
# predict function, capture the graph, then predict inside that graph.
model1 = load_model(model_path1)
model1._make_predict_function()  # <very important> build the predict function up front to speed up predict
graph1 = tf.get_default_graph()

def model1_predict():
    """Excerpt: run ``model1.predict`` inside the graph captured at load time.

    ``***`` below is a placeholder for the preprocessed input batch (for
    example ``img_nad`` in the full test code above); this excerpt is not
    runnable as written.
    """
    # `global` is not strictly required just to read graph1; kept as in the
    # full example.
    global graph1
    with graph1.as_default():
        pred = model1.predict(***, batch_size=1, verbose=0)
osakasho
PythonでDeepLearningしたい人生。 ちょっとまともなGPUのPCを購入してゴリゴリ機械学習させてゆく。
https://shomyapp.net/
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up on information in the technical fields you are interested in
  2. You can efficiently read useful information later
    By "stocking" the articles you like, you can find them again right away