LoginSignup
64
74

More than 5 years have passed since last update.

scikit-learnとflaskで簡単な機械学習✕Webアプリ

Last updated at Posted at 2017-07-30

簡単な機械学習とWebアプリ

Webサイトからの入力に対して、機械学習で作った判定モデルで結果を画面に表示する簡単なデモアプリを作ってみました。
アヤメのがく片と花びらの大きさ(縦、横)を入力すると、品種を判定して表示します。

コードはこちらにあります。
https://github.com/shibuiwilliam/mlweb

全体像

こんな感じです。

5.JPG

がく片や花びらを入力するフロントエンドと判定して結果を返すバックエンドという構成です。
スーパーシンプルです。
言語はPython3.6で、Webにはflaskとwtform、機械学習にはscikit-learnを使っています。

使い方

開発環境はCentOS7.3とPython3.6です。
ライブラリとしてflask、wtform、scikit-learn、Jupyter Notebookを入れています。
この辺は概ねAnaconda3で導入していますが、flaskとwtformはpip installしています。
コードはgit cloneしてください。


git clone https://github.com/shibuiwilliam/mlweb.git

Webサイトはポート5000で公開されます。
OSやネットワークで5000へのアクセスが可能なように設定が必要です。
(OS的にはselinuxとfirewalld)

Webサイトの実行方法は以下です。


nohup python web.py &

実行ログはnohup.outに吐かれます。
順次ログを見たければ以下を実行します。


tail -f nohup.out

機械学習

アヤメの判定モデルはscikit-learnのランダムフォレストとGridSearchCVを使っています。

scikit-learnで提供しているアヤメデータセットを取得し、ランダムフォレストのハイパーパラメータとともにGridSearchCVでモデルのパラメータを作ります。


# iris_train.ipynb


from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.externals import joblib
import numpy as np

# use Iris dataset
data = load_iris()
x = data.data
y = data.target
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3, random_state=54321)

# run gridsearchcv on random forest classifier
forest = RandomForestClassifier()
param_grid = {
    'n_estimators'      : [5, 10, 20, 30, 50, 100, 300],
    'random_state'      : [0],
    'n_jobs'            : [1],
    'min_samples_split' : [3, 5, 10, 15, 20, 25, 30, 40, 50, 100],
    'max_depth'         : [3, 5, 10, 15, 20, 25, 30, 40, 50, 100]
}
forestGrid = GridSearchCV(forest, param_grid)
fgFit = forestGrid.fit(x_train, y_train)

作ったモデル・パラメータをランダムフォレストでfitするとモデルができあがります。
モデルはいったんpickleにして保存します。
Webサイトで使う際は、保存されたモデルを呼び出します。


# set the best params to fit random forest classifier
forest.set_params(**fgFit.best_params_)
forest.fit(x, y)

# save the model as pickle
joblib.dump(forest, './rfcParam.pkl', compress=True)

# load the model
forest = joblib.load('./rfcParam.pkl')

# predict
t = np.array([5.1,  3.5,  1.4,  0.2])
t = t.reshape(1,-1)
print(forest.predict(t))

Webサイト

Webはflaskとwtformで作っています。
まずはWebで使う関数を定義します(web.py)。


# web.py

import numpy as np

# 入力をcsvでログとして保存します。
def insert_csv(data):
    import csv
    import uuid
    tuid = str(uuid.uuid1())
    with open("./logs/"+tuid+".csv", "a") as f:
        writer = csv.writer(f, lineterminator='\n')
        writer.writerow(["sepalLength","sepalWidth","petalLength","petalWidth"])
        writer.writerow(data)
    return tuid

# scikit-learnで作ったモデルを使って判定します。
def predictIris(params):
    from sklearn.externals import joblib
    # load the model
    forest = joblib.load('./rfcParam.pkl')
    # predict
    params = params.reshape(1,-1)
    pred = forest.predict(params)
    return pred

# 判定は0,1,2で出力されるので、アヤメの品種名に変換します。
def getIrisName(irisId):
    if irisId == 0: return "Iris Setosa"
    elif irisId == 1: return "Iris Versicolour"
    elif irisId == 2: return "Iris Virginica"
    else: return "Error"

WebのHTMLとflaskを使ったPythonコードは以下になります。

まずは入力フォーム(templates/irisPred.html)です。


<!doctype html>
<html>
   <body>

      <h2 style = "text-align: center;">Enter Iris Params</h2>

      {% for message in form.SepalLength.errors %}
         <div>{{ message }}</div>
      {% endfor %}

      {% for message in form.SepalWidth.errors %}
         <div>{{ message }}</div>
      {% endfor %}

      {% for message in form.PetalLength.errors %}
         <div>{{ message }}</div>
      {% endfor %}

      {% for message in form.PetalWidth.errors %}
         <div>{{ message }}</div>
      {% endfor %}

      <form action = "" method = post>
         <fieldset>
            <legend>Enter parameters here.</legend>

            <div style = font-size:20px; font-weight:bold; margin-left:150px;>
               {{ form.SepalLength.label }}<br>
               {{ form.SepalLength }}
               <br>
               {{ form.SepalWidth.label }}<br>
               {{ form.SepalWidth }}
               <br>
               {{ form.PetalLength.label }}<br>
               {{ form.PetalLength }}
               <br>
               {{ form.PetalWidth.label }}<br>
               {{ form.PetalWidth }}
               <br>
               {{ form.submit }}
            </div>

         </fieldset>
      </form>

   </body>
</html>

判定結果を表示する画面(templates/success.html)です。


<!doctype html>
<title>Hello from Iris</title>
{% if irisName %}
  <h1>It is {{ irisName }}</h1>
{% else %}
  <h1>Hello from Iris</h1>
{% endif %}

Pythonのフォーム定義(web.py)です。


# web.py

from flask import Flask, render_template, request, flash
from wtforms import Form, FloatField, SubmitField, validators, ValidationError

# App config.
DEBUG = True
app = Flask(__name__)
app.config.from_object(__name__)
app.config['SECRET_KEY'] = 'fda0e618-685c-11e7-bb40-fa163eb65161'

class IrisForm(Form):
    SepalLength = FloatField("Sepal Length in cm",
                     [validators.InputRequired("all parameters are required!"),
                     validators.NumberRange(min=0, max=10)])
    SepalWidth = FloatField("Sepal Width in cm",
                     [validators.InputRequired("all parameters are required!"),
                     validators.NumberRange(min=0, max=10)])
    PetalLength = FloatField("Petal Length in cm",
                     [validators.InputRequired("all parameters are required!"),
                     validators.NumberRange(min=0, max=10)])
    PetalWidth = FloatField("Petal Width in cm",
                     [validators.InputRequired("all parameters are required!"),
                     validators.NumberRange(min=0, max=10)])
    submit = SubmitField("Try")

@app.route('/irisPred', methods = ['GET', 'POST'])
def irisPred():
    form = IrisForm(request.form)
    if request.method == 'POST':
        if form.validate() == False:
            flash("You need all parameters")
            return render_template('irisPred.html', form = form)
        else:            
            SepalLength = float(request.form["SepalLength"])            
            SepalWidth = float(request.form["SepalWidth"])            
            PetalLength = float(request.form["PetalLength"])            
            PetalWidth = float(request.form["PetalWidth"])
            params = np.array([SepalLength, SepalWidth, PetalLength, PetalWidth])
            print(params)
            insert_csv(params)
            pred = predictIris(params)
            irisName = getIrisName(pred)

            return render_template('success.html', irisName=irisName)
    elif request.method == 'GET':
        return render_template('irisPred.html', form = form)

if __name__ == "__main__":
    app.debug = True
    app.run(host='0.0.0.0')

実行イメージ

画面の入力フォームに適当な数字(0~10の間)を入力してTryします。

1.JPG

画面に判定されたアヤメの品種名が表示されます。

2.JPG

入力エラーのチェックします。

4.JPG

64
74
3

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
64
74