Python
Flask
DeepLearning
Keras

Flask & Keras で画像分類サービスのプロトタイプ作成

やること

KerasFlaskを組み合わせて、以下のようなサービスのプロトタイプを作る

花の名前を人工知能が教えてくれるサービス

分類するもの

バスケが好きなので、NBA(アメリカのバスケリーグ)が誇る怪物達を分類してみる

  1. Lebron James
    10.jpg

  2. Kevin Durant
    11.jpg

  3. Stephen Curry
    17.jpg

データ収集

  • Seleniumを使って、各選手の画像を収集(1人250枚、合計750枚)
  • 自作したツールを使用して顔部分を切り抜き 64 x 64 にリサイズ

07.jpg

16.jpg

09.jpg

Model Summary

_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d_1 (Conv2D)            (None, 62, 62, 32)        896
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 60, 60, 64)        18496
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 30, 30, 64)        0
_________________________________________________________________
dropout_1 (Dropout)          (None, 30, 30, 64)        0
_________________________________________________________________
flatten_1 (Flatten)          (None, 57600)             0
_________________________________________________________________
dense_1 (Dense)              (None, 128)               7372928
_________________________________________________________________
dropout_2 (Dropout)          (None, 128)               0
_________________________________________________________________
dense_2 (Dense)              (None, 3)                 387
=================================================================

Training

収集した画像を水増ししてTrainingした後、モデルをmy_model.h5として保存
水増し、Trainingの詳細は割愛させて頂く

Flaskとの連携

行っていることを簡単にまとめると、
1. modelをロード(global変数にしているのは、分類を行う際に毎回モデルがロードされるのを防ぐため)
2. POSTで送られてきた画像をPillowで受け取りNumpy Arrayに変換
3. imgmodelに渡しpredict
4. predictの結果をjsonで返す
5. 分類の結果をページ上に反映

main.py
from flask import Flask, redirect, request, jsonify
from keras import models
import numpy as np
from PIL import Image
import io


app = Flask(__name__)
model = None


def load_model():
    global model
    model = models.load_model('my_model.h5')
    model.summary()
    print('Loaded the model')


@app.route('/')
def index():
    return redirect('/static/index.html')


@app.route('/predict', methods=['POST'])
def predict():
    if request.files and 'picfile' in request.files:
        img = request.files['picfile'].read()
        img = Image.open(io.BytesIO(img))
        img.save('test.jpg')
        img = np.asarray(img) / 255.
        img = np.expand_dims(img, axis=0)
        pred = model.predict(img)

        players = [
            'Lebron James',
            'Stephen Curry',
            'Kevin Durant',
        ]

        confidence = str(round(max(pred[0]), 3))
        pred = players[np.argmax(pred)]

        data = dict(pred=pred, confidence=confidence)
        return jsonify(data)

    return 'Picture info did not get saved.'


@app.route('/currentimage', methods=['GET'])
def current_image():
    fileob = open('test.jpg', 'rb')
    data = fileob.read()
    return data


if __name__ == '__main__':
    load_model()
    app.run(debug=False, port=5000)

index.html
<!DOCTYPE html>
<html lang="en">
<head>
    <title>Camera Project</title>
    <meta charset="utf-8">
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <link rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
    <script src="https://ajax.googleapis.com/ajax/libs/jquery/1.11.3/jquery.min.js"></script>
    <style>
    </style>
</head>
<body>

<div class="container">
    <form class="form-group">
        <input type="file" accept="image/*;device=camera" class="mt-2" id="fileField" name="fileField"/>
    </form>
    <img id="imageID" class="img-responsive" src="/currentimage" />
    <p>
      <span id="prediction">-</span><br>
      <span id="confidence">-</span>
    </p>
</div>
<script src="app.js"></script>
</body>
</html>
app.js
$(function(){

  var getImageSuccess = function(data){
    $("#imageID").attr("src", "/currentimage");

  };

  var getImageFailure = function(data){
    console.log("No image data");
  };

  var successResult = function(data){
    $("#prediction").text("Prediction: " + data.pred)
    $("#confidence").text("Confidence: " + data.confidence)
    var req = {
      url: "/currentimage",
      method: "get"
    };

    var promise = $.ajax(req)
    promise.then(getImageSuccess, getImageFailure);
  };
  var failureResult = function(data){
    alert("Didn't work");
  };

  var fileChange = function(evt){
    var fileOb = $("#fileField")[0].files[0];
    var formData = new FormData();
    formData.append("picfile", fileOb);
    var req = {
      url: "/predict",
      method: "post",
      processData: false,
      contentType: false,
      data: formData
    };

    var promise = $.ajax(req);
    promise.then(successResult, failureResult);
  };

  $("#fileField").change(fileChange);
});

デモ

demo.gif

ソースコード

遊んでみて下さい
https://github.com/harupy/keras-flask-classifier

感想

画像分類以外にも応用できそう