LoginSignup
0
2

More than 3 years have passed since last update.

【Flask+Keras】サーバーで複数モデルを高速で推論させる方法

Last updated at Posted at 2020-07-28

結論

keras==2.2.4
tensorflow==1.14.0
numpy==1.16.4

テストコード

from flask import Flask
import time

import numpy as np
import tensorflow as tf
from keras.models import load_model
from keras.preprocessing.image import img_to_array, load_img

# Flask application object; the route handler further down is registered on it.
app = Flask(__name__)

# --- Model 1 -----------------------------------------------------------------
# Loaded once at import time so that request handlers can reuse it.
model_path1 = "mnist.h5"
model1 = load_model(model_path1)
# Label strings, indexed by the argmax of the model output.
label1 = ["l0", "l1", "l2", "l3", "l4", "l5", "l6", "l7", "l8", "l9"]
# <Very important> Speeds up predict.
# NOTE(review): this is a private Keras method; presumably it builds the
# predict function up front so the first request does not pay for it --
# confirm against the pinned Keras version.
model1._make_predict_function()
# Keep a handle on the current default graph so predict can run under it later.
graph1 = tf.get_default_graph()


# --- Model 2 -----------------------------------------------------------------
# Same setup as model 1; note that it loads the same "mnist.h5" file.
model_path2 = "mnist.h5"
model2 = load_model(model_path2)
label2 = ["l0", "l1", "l2", "l3", "l4", "l5", "l6", "l7", "l8", "l9"]
model2._make_predict_function()
# NOTE(review): nothing switches the default graph between the two
# get_default_graph() calls, so graph2 presumably refers to the same graph
# object as graph1 -- confirm.
graph2 = tf.get_default_graph()

def model1_predict(img_path):
    """Classify the image at ``img_path`` with ``model1`` and print the result.

    The image is loaded as a 28x28 grayscale picture, divided by 255 and
    given a leading batch axis before being passed to ``model1.predict``.
    The highest score and the matching entry of ``label1`` are printed;
    nothing is returned.

    Args:
        img_path: Path of the image file to classify.
    """
    # Convert to an array exactly once. Calling img_to_array() again on its
    # own output is redundant, and in keras_preprocessing a second pass would
    # transpose the axes again under the "channels_first" data format.
    # NOTE(review): newer keras_preprocessing releases deprecate `grayscale=`
    # in favour of color_mode="grayscale" -- confirm against the pinned
    # version before switching.
    img = img_to_array(load_img(img_path, target_size=(28, 28), grayscale=True))
    img_nad = (img / 255)[None, ...]
    # graph1 is only read here, so no `global` declaration is needed.
    # Predict under the graph captured at load time; presumably required
    # because Flask may serve the request on a thread whose default graph
    # differs -- confirm.
    with graph1.as_default():
        pred = model1.predict(img_nad, batch_size=1, verbose=0)
    score = np.max(pred)
    pred_label = label1[np.argmax(pred[0])]
    print("スコア:", score, "ラベル:", pred_label)

def model2_predict(img_path):
    """Classify the image at ``img_path`` with ``model2`` and print the result.

    The image is loaded as a 28x28 grayscale picture, divided by 255 and
    given a leading batch axis before being passed to ``model2.predict``.
    The highest score and the matching entry of ``label2`` are printed;
    nothing is returned.

    Args:
        img_path: Path of the image file to classify.
    """
    # Convert to an array exactly once. Calling img_to_array() again on its
    # own output is redundant, and in keras_preprocessing a second pass would
    # transpose the axes again under the "channels_first" data format.
    # NOTE(review): newer keras_preprocessing releases deprecate `grayscale=`
    # in favour of color_mode="grayscale" -- confirm against the pinned
    # version before switching.
    img = img_to_array(load_img(img_path, target_size=(28, 28), grayscale=True))
    img_nad = (img / 255)[None, ...]
    # graph2 is only read here, so no `global` declaration is needed.
    # Predict under the graph captured at load time; presumably required
    # because Flask may serve the request on a thread whose default graph
    # differs -- confirm.
    with graph2.as_default():
        pred = model2.predict(img_nad, batch_size=1, verbose=0)
    score = np.max(pred)
    pred_label = label2[np.argmax(pred[0])]
    print("スコア:", score, "ラベル:", pred_label)

@app.route("/", methods=['GET', 'POST'])
def webapp():
    """Run both models on ``mnist_test.jpg`` and report how long each took.

    Each elapsed time is printed to stdout, and both are returned to the
    client as a small HTML fragment with the values rounded to three
    decimal places.

    Returns:
        str: HTML fragment with the two elapsed times in seconds.
    """
    # perf_counter() is a monotonic, high-resolution clock intended for
    # measuring intervals; time.time() follows the wall clock and can jump
    # if the system time is adjusted mid-measurement.
    start1 = time.perf_counter()
    model1_predict("mnist_test.jpg")
    end1 = time.perf_counter() - start1
    print("処理時間<model1>: ", end1, "秒")

    start2 = time.perf_counter()
    model2_predict("mnist_test.jpg")
    end2 = time.perf_counter() - start2
    print("処理時間<model2>: ", end2, "秒")

    output = "<p>model1:"+str(round(end1, 3))+"秒</p><br><p>model2:"+str(round(end2, 3))+"秒</p>"
    return output

if __name__ == "__main__":
    # Only when executed as a script: start the Flask server on port 5000
    # with debug mode switched off.
    app.run(debug=False, port=5000)

重要な部分

# Load the model once, up front.
model1 = load_model(model_path1)
# <Very important> Speeds up predict.
model1._make_predict_function()
# Keep a handle on the default graph so predict can later run under it.
graph1 = tf.get_default_graph()

def model1_predict():
    # Excerpt: `***` below is a placeholder for the input batch, not valid Python.
    global graph1
    # Run predict under the graph stored in graph1.
    with graph1.as_default():
        pred = model1.predict(***, batch_size=1, verbose=0)
0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2