やること
Keras
とFlask
を組み合わせて、以下のようなサービスのプロトタイプを作る
分類するもの
バスケが好きなので、NBA(アメリカのバスケリーグ)が誇る怪物達を分類してみる
データ収集
-
Selenium
を使って、各選手の画像を収集(1人250枚、合計750枚) - 自作したツールを使用して顔部分を切り抜き 64 x 64 にリサイズ
Model Summary
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d_1 (Conv2D) (None, 62, 62, 32) 896
_________________________________________________________________
conv2d_2 (Conv2D) (None, 60, 60, 64) 18496
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 30, 30, 64) 0
_________________________________________________________________
dropout_1 (Dropout) (None, 30, 30, 64) 0
_________________________________________________________________
flatten_1 (Flatten) (None, 57600) 0
_________________________________________________________________
dense_1 (Dense) (None, 128) 7372928
_________________________________________________________________
dropout_2 (Dropout) (None, 128) 0
_________________________________________________________________
dense_2 (Dense) (None, 3) 387
=================================================================
Training
収集した画像を水増ししてTrainingした後、モデルをmy_model.h5
として保存
水増し、Trainingの詳細は割愛させて頂く
Flaskとの連携
行っていることを簡単にまとめると、
-
model
をロード(global
変数にしているのは、分類を行う際に毎回モデルがロードされるのを防ぐため) -
POST
で送られてきた画像をPillow
で受け取りNumpy Array
に変換 -
img
をmodel
に渡しpredict
-
predict
の結果をjson
で返す - 分類の結果をページ上に反映
main.py
from flask import Flask, redirect, request, jsonify
from keras import models
import numpy as np
from PIL import Image
import io
app = Flask(__name__)
model = None
def load_model():
global model
model = models.load_model('my_model.h5')
model.summary()
print('Loaded the model')
@app.route('/')
def index():
return redirect('/static/index.html')
@app.route('/predict', methods=['POST'])
def predict():
if request.files and 'picfile' in request.files:
img = request.files['picfile'].read()
img = Image.open(io.BytesIO(img))
img.save('test.jpg')
img = np.asarray(img) / 255.
img = np.expand_dims(img, axis=0)
pred = model.predict(img)
players = [
'Lebron James',
'Stephen Curry',
'Kevin Durant',
]
confidence = str(round(max(pred[0]), 3))
pred = players[np.argmax(pred)]
data = dict(pred=pred, confidence=confidence)
return jsonify(data)
return 'Picture info did not get saved.'
@app.route('/currentimage', methods=['GET'])
def current_image():
fileob = open('test.jpg', 'rb')
data = fileob.read()
return data
if __name__ == '__main__':
load_model()
app.run(debug=False, port=5000)
index.html
<!DOCTYPE html>
<html lang="en">
<head>
<title>Camera Project</title>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<link rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
<script src="https://ajax.googleapis.com/ajax/libs/jquery/1.11.3/jquery.min.js"></script>
<style>
</style>
</head>
<body>
<div class="container">
<form class="form-group">
<input type="file" accept="image/*;device=camera" class="mt-2" id="fileField" name="fileField"/>
</form>
<img id="imageID" class="img-responsive" src="/currentimage" />
<p>
<span id="prediction">-</span><br>
<span id="confidence">-</span>
</p>
</div>
<script src="app.js"></script>
</body>
</html>
app.js
$(function(){
var getImageSuccess = function(data){
$("#imageID").attr("src", "/currentimage");
};
var getImageFailure = function(data){
console.log("No image data");
};
var successResult = function(data){
$("#prediction").text("Prediction: " + data.pred)
$("#confidence").text("Confidence: " + data.confidence)
var req = {
url: "/currentimage",
method: "get"
};
var promise = $.ajax(req)
promise.then(getImageSuccess, getImageFailure);
};
var failureResult = function(data){
alert("Didn't work");
};
var fileChange = function(evt){
var fileOb = $("#fileField")[0].files[0];
var formData = new FormData();
formData.append("picfile", fileOb);
var req = {
url: "/predict",
method: "post",
processData: false,
contentType: false,
data: formData
};
var promise = $.ajax(req);
promise.then(successResult, failureResult);
};
$("#fileField").change(fileChange);
});
ソースコード
遊んでみて下さい
https://github.com/harupy/keras-flask-classifier
感想
画像分類以外にも応用できそう