簡単な機械学習とWebアプリ
Webサイトからの入力に対して、機械学習で作った判定モデルで結果を画面に表示する簡単なデモアプリを作ってみました。
アヤメのがく片と花びらの大きさ(縦、横)を入力すると、品種を判定して表示します。
コードはこちらにあります。
https://github.com/shibuiwilliam/mlweb
全体像
こんな感じです。
がく片や花びらを入力するフロントエンドと判定して結果を返すバックエンドという構成です。
スーパーシンプルです。
言語はPython3.6で、Webにはflaskとwtform、機械学習にはscikit-learnを使っています。
使い方
開発環境はCentOS7.3とPython3.6です。
ライブラリとしてflask、wtform、scikit-learn、Jupyter Notebookを入れています。
この辺は概ねAnaconda3で導入していますが、flaskとwtformはpip installしています。
コードはgit cloneしてください。
git clone https://github.com/shibuiwilliam/mlweb.git
Webサイトはポート5000で公開されます。
OSやネットワークで5000へのアクセスが可能なように設定が必要です。
(OS的にはselinuxとfirewalld)
Webサイトの実行方法は以下です。
nohup python web.py &
実行ログはnohup.outに吐かれます。
順次ログを見たければ以下を実行します。
tail -f nohup.out
機械学習
アヤメの判定モデルはscikit-learnのランダムフォレストとGridSearchCVを使っています。
scikit-learnで提供しているアヤメデータセットを取得し、ランダムフォレストのハイパーパラメータとともにGridSearchCVでモデルのパラメータを作ります。
# iris_train.ipynb
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.externals import joblib
import numpy as np
# use Iris dataset
data = load_iris()
x = data.data
y = data.target
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3, random_state=54321)
# run gridsearchcv on random forest classifier
forest = RandomForestClassifier()
param_grid = {
'n_estimators' : [5, 10, 20, 30, 50, 100, 300],
'random_state' : [0],
'n_jobs' : [1],
'min_samples_split' : [3, 5, 10, 15, 20, 25, 30, 40, 50, 100],
'max_depth' : [3, 5, 10, 15, 20, 25, 30, 40, 50, 100]
}
forestGrid = GridSearchCV(forest, param_grid)
fgFit = forestGrid.fit(x_train, y_train)
作ったモデル・パラメータをランダムフォレストでfitするとモデルができあがります。
モデルはいったんpickleにして保存します。
Webサイトで使う際は、保存されたモデルを呼び出します。
# set the best params to fit random forest classifier
forest.set_params(**fgFit.best_params_)
forest.fit(x, y)
# save the model as pickle
joblib.dump(forest, './rfcParam.pkl', compress=True)
# load the model
forest = joblib.load('./rfcParam.pkl')
# predict
t = np.array([5.1, 3.5, 1.4, 0.2])
t = t.reshape(1,-1)
print(forest.predict(t))
Webサイト
Webはflaskとwtformで作っています。
まずはWebで使う関数を定義します(web.py)。
# web.py
import numpy as np
# 入力をcsvでログとして保存します。
def insert_csv(data):
import csv
import uuid
tuid = str(uuid.uuid1())
with open("./logs/"+tuid+".csv", "a") as f:
writer = csv.writer(f, lineterminator='\n')
writer.writerow(["sepalLength","sepalWidth","petalLength","petalWidth"])
writer.writerow(data)
return tuid
# scikit-learnで作ったモデルを使って判定します。
def predictIris(params):
from sklearn.externals import joblib
# load the model
forest = joblib.load('./rfcParam.pkl')
# predict
params = params.reshape(1,-1)
pred = forest.predict(params)
return pred
# 判定は0,1,2で出力されるので、アヤメの品種名に変換します。
def getIrisName(irisId):
if irisId == 0: return "Iris Setosa"
elif irisId == 1: return "Iris Versicolour"
elif irisId == 2: return "Iris Virginica"
else: return "Error"
WebのHTMLとflaskを使ったPythonコードは以下になります。
まずは入力フォーム(templates/irisPred.html)です。
<!doctype html>
<html>
<body>
<h2 style = "text-align: center;">Enter Iris Params</h2>
{% for message in form.SepalLength.errors %}
<div>{{ message }}</div>
{% endfor %}
{% for message in form.SepalWidth.errors %}
<div>{{ message }}</div>
{% endfor %}
{% for message in form.PetalLength.errors %}
<div>{{ message }}</div>
{% endfor %}
{% for message in form.PetalWidth.errors %}
<div>{{ message }}</div>
{% endfor %}
<form action = "" method = post>
<fieldset>
<legend>Enter parameters here.</legend>
<div style = font-size:20px; font-weight:bold; margin-left:150px;>
{{ form.SepalLength.label }}<br>
{{ form.SepalLength }}
<br>
{{ form.SepalWidth.label }}<br>
{{ form.SepalWidth }}
<br>
{{ form.PetalLength.label }}<br>
{{ form.PetalLength }}
<br>
{{ form.PetalWidth.label }}<br>
{{ form.PetalWidth }}
<br>
{{ form.submit }}
</div>
</fieldset>
</form>
</body>
</html>
判定結果を表示する画面(templates/success.html)です。
<!doctype html>
<title>Hello from Iris</title>
{% if irisName %}
<h1>It is {{ irisName }}</h1>
{% else %}
<h1>Hello from Iris</h1>
{% endif %}
Pythonのフォーム定義(web.py)です。
# web.py
from flask import Flask, render_template, request, flash
from wtforms import Form, FloatField, SubmitField, validators, ValidationError
# App config.
DEBUG = True
app = Flask(__name__)
app.config.from_object(__name__)
app.config['SECRET_KEY'] = 'fda0e618-685c-11e7-bb40-fa163eb65161'
class IrisForm(Form):
SepalLength = FloatField("Sepal Length in cm",
[validators.InputRequired("all parameters are required!"),
validators.NumberRange(min=0, max=10)])
SepalWidth = FloatField("Sepal Width in cm",
[validators.InputRequired("all parameters are required!"),
validators.NumberRange(min=0, max=10)])
PetalLength = FloatField("Petal Length in cm",
[validators.InputRequired("all parameters are required!"),
validators.NumberRange(min=0, max=10)])
PetalWidth = FloatField("Petal Width in cm",
[validators.InputRequired("all parameters are required!"),
validators.NumberRange(min=0, max=10)])
submit = SubmitField("Try")
@app.route('/irisPred', methods = ['GET', 'POST'])
def irisPred():
form = IrisForm(request.form)
if request.method == 'POST':
if form.validate() == False:
flash("You need all parameters")
return render_template('irisPred.html', form = form)
else:
SepalLength = float(request.form["SepalLength"])
SepalWidth = float(request.form["SepalWidth"])
PetalLength = float(request.form["PetalLength"])
PetalWidth = float(request.form["PetalWidth"])
params = np.array([SepalLength, SepalWidth, PetalLength, PetalWidth])
print(params)
insert_csv(params)
pred = predictIris(params)
irisName = getIrisName(pred)
return render_template('success.html', irisName=irisName)
elif request.method == 'GET':
return render_template('irisPred.html', form = form)
if __name__ == "__main__":
app.debug = True
app.run(host='0.0.0.0')
実行イメージ
画面の入力フォームに適当な数字(0~10の間)を入力してTryします。
画面に判定されたアヤメの品種名が表示されます。
入力エラーのチェックします。