やること
Keras
とFlask
を組み合わせて、以下のようなサービスのプロトタイプを作る
分類するもの
バスケが好きなので、NBA(アメリカのバスケリーグ)が誇る怪物達を分類してみる
データ収集
-
Selenium
を使って、各選手の画像を収集(1人250枚、合計750枚) - 自作したツールを使用して顔部分を切り抜き 64 x 64 にリサイズ
Model Summary
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d_1 (Conv2D) (None, 62, 62, 32) 896
_________________________________________________________________
conv2d_2 (Conv2D) (None, 60, 60, 64) 18496
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 30, 30, 64) 0
_________________________________________________________________
dropout_1 (Dropout) (None, 30, 30, 64) 0
_________________________________________________________________
flatten_1 (Flatten) (None, 57600) 0
_________________________________________________________________
dense_1 (Dense) (None, 128) 7372928
_________________________________________________________________
dropout_2 (Dropout) (None, 128) 0
_________________________________________________________________
dense_2 (Dense) (None, 3) 387
=================================================================
Training
収集した画像を水増ししてTrainingした後、モデルをmy_model.h5
として保存
水増し、Trainingの詳細は割愛させて頂く
Flaskとの連携
行っていることを簡単にまとめると、
1. model
をロード(global
変数にしているのは、分類を行う際に毎回モデルがロードされるのを防ぐため)
2. POST
で送られてきた画像をPillow
で受け取りNumpy Array
に変換
3. img
をmodel
に渡しpredict
4. predict
の結果をjson
で返す
5. 分類の結果をページ上に反映
main.py
from flask import Flask, redirect, request, jsonify
from keras import models
import numpy as np
from PIL import Image
import io
app = Flask(__name__)
model = None
def load_model():
global model
model = models.load_model('my_model.h5')
model.summary()
print('Loaded the model')
@app.route('/')
def index():
return redirect('/static/index.html')
@app.route('/predict', methods=['POST'])
def predict():
if request.files and 'picfile' in request.files:
img = request.files['picfile'].read()
img = Image.open(io.BytesIO(img))
img.save('test.jpg')
img = np.asarray(img) / 255.
img = np.expand_dims(img, axis=0)
pred = model.predict(img)
players = [
'Lebron James',
'Stephen Curry',
'Kevin Durant',
]
confidence = str(round(max(pred[0]), 3))
pred = players[np.argmax(pred)]
data = dict(pred=pred, confidence=confidence)
return jsonify(data)
return 'Picture info did not get saved.'
@app.route('/currentimage', methods=['GET'])
def current_image():
fileob = open('test.jpg', 'rb')
data = fileob.read()
return data
if __name__ == '__main__':
load_model()
app.run(debug=False, port=5000)
index.html
<!DOCTYPE html>
<html lang="en">
<head>
<title>Camera Project</title>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<link rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
<script src="https://ajax.googleapis.com/ajax/libs/jquery/1.11.3/jquery.min.js"></script>
<style>
</style>
</head>
<body>
<div class="container">
<form class="form-group">
<input type="file" accept="image/*;device=camera" class="mt-2" id="fileField" name="fileField"/>
</form>
<img id="imageID" class="img-responsive" src="/currentimage" />
<p>
<span id="prediction">-</span><br>
<span id="confidence">-</span>
</p>
</div>
<script src="app.js"></script>
</body>
</html>
app.js
$(function(){
var getImageSuccess = function(data){
$("#imageID").attr("src", "/currentimage");
};
var getImageFailure = function(data){
console.log("No image data");
};
var successResult = function(data){
$("#prediction").text("Prediction: " + data.pred)
$("#confidence").text("Confidence: " + data.confidence)
var req = {
url: "/currentimage",
method: "get"
};
var promise = $.ajax(req)
promise.then(getImageSuccess, getImageFailure);
};
var failureResult = function(data){
alert("Didn't work");
};
var fileChange = function(evt){
var fileOb = $("#fileField")[0].files[0];
var formData = new FormData();
formData.append("picfile", fileOb);
var req = {
url: "/predict",
method: "post",
processData: false,
contentType: false,
data: formData
};
var promise = $.ajax(req);
promise.then(successResult, failureResult);
};
$("#fileField").change(fileChange);
});
デモ
ソースコード
遊んでみて下さい
https://github.com/harupy/keras-flask-classifier
感想
画像分類以外にも応用できそう