2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Flask + 機械学習 Webアプリ作成

Last updated at Posted at 2024-07-04

やりたいこと

バックエンドの学習としてFlaskを使って機械学習の簡単なアプリを作成したので備忘録として記載します。モデル学習は既に終わっている状態として推論工程をGUI上で実施できるアプリを想定しています。
モデル学習部分は簡易的にLightGBMで回帰モデルを作成しています。
間違いや認識違いなどあればご指摘ください。

最終的なイメージ

以下図のように3画面を作成し、図左のホーム画面で新規推論と結果画面に分岐遷移できるようにリンクボタンを配置します。

image.png

  • ホーム画面

新規推論結果管理画面のリンクを作成して各画面に遷移するようにします。

  • 新規推論用の画面

説明変数を入力する画面です。値を入力してボタンを押すことでバックエンドで推論が実行されDBに値が保存されます。3,4はバックエンドの処理なので画面上では何も変化はありません。
 1. 説明変数の値を入力
 2. 新規予測ボタンを押す
 3. (機械学習モデルに値が入力され推論が実行される)
 4. (モデルが出力した推論結果をDBに登録)
 5. ホーム画面に遷移する

  • 予測結果画面

モデルが出力した推論結果をDBから取得し画面に表示ます。

作業手順

1. ライブラリのインストール

今回使用するライブラリのインストールを行います。

pip install flask
pip install flask_sqlalchemy

※それ以外で必要なライブラリがあれば適宜pip installします。

2. フォルダ構成

最終的に以下のようなフォルダ構成となる予定です。
それぞれのファイルがどのようなものかを以降説明したいと思います。
root/
 └─ Backend/
    ├─ data/
    │  └─ BostonHousing.csv
    ├─ instance/
    │  └─ params.db
    ├─ model/
    │  └─ lgb_model.pkl
    ├─ templates/
    │  └─ create.html
    │  └─ index.html
    │  └─ results.html
    ├─ app.py
    ├─ config.py
    ├─ fit.py
    └─ requirements.txt

3. fit.pyファイルの作成、実行

まずは推論処理で使用する機械学習モデルの作成をfit.pyファイルで行います。
使用するデータセットはBostonHousing.csvというボストン住宅価格を予測するデータセットです。(https://qiita.com/yut-nagase/items/6c2bc025e7eaa7493f89)
このデータをLightGBMモデルで学習し.pklファイルとして出力しています。
最終的にはこのモデルを利用して、

  1. 説明変数をモデルに入力
  2. モデルが目的変数の予測値を出力
  3. 目的変数の数値を画面上に表示
    という処理をWebアプリ上で繰り返し行えるようにすることが目標です。
    ターミナル上でfit.pyファイルを実行することで事前にモデルを作成しておきます。
python3 fit.py

fit.pyファイルの中身は以下となっています。

fit.py
import numpy as np
import pandas as pd
import lightgbm as lgb
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
import joblib
import config

# データの読み込み
df = pd.read_csv('data/BostonHousing.csv',encoding='shift-jis')
# 目的変数
y = df[config.TARGET]
X = df[config.COLUMNS]

# データの分割
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)

# データの整形
train = lgb.Dataset(X_train, y_train)
valid = lgb.Dataset(X_test, y_test)

#モデルパラメータの設定
params = {'metric' : 'rmse'}

# モデル学習
model = lgb.train(params, train)

#モデル予測
pred = model.predict(X_test)
print(pred)

# 学習済みモデルの保存
joblib.dump(model, "model/lgb_model.pkl", compress=True)

4. app.pyファイルの作成

次にapp.pyファイルを作成します。
こちらのファイルが今回のバックエンドのメイン処理ファイルとなります。

app.py
from flask import Flask
from flask import render_template, request, redirect
from flask_sqlalchemy import SQLAlchemy
from datetime import datetime
import pytz
import pandas as pd
import numpy as np
import joblib
import config

app = Flask(__name__)
app.config['SQLALCHEMY_DATABASE_URI'] = "sqlite:///params.db"
db = SQLAlchemy(app)

# カラムの取り出し
df = pd.read_csv('data/BostonHousing.csv',encoding='shift-jis')
# 目的変数
y = df[config.TARGET]
X = df[config.COLUMNS]
col_names = X.columns

def predict(parameters):
    model = joblib.load('model/lgb_model.pkl')
    params = parameters.reshape(1,-1)
    pred = model.predict(params)
    return pred

class Post(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    value = db.Column(db.Float, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, default=datetime.now(pytz.timezone("Asia/Tokyo")))

@app.route("/", methods=["GET", "POST"])
def index():
    # "/"の場合はindex.htmlを表示させる
    return render_template("index.html")

@app.route("/create", methods=["GET", "POST"])
def create():
    if request.method == "POST":
        # POST(value送信)の場合は推論して"/"にredirectする
        pred_list = []
        for col in col_names:
            # 入力フォームに値が入力されていることを確認
            if request.form.get(col) == "":
                return redirect("/create")
            pred_value = request.form.get(col)
            pred_list.append(pred_value)
        x = np.array(pred_list)
        pred = predict(x)

        # 予測をしたのでDBを更新する
        post = Post()  # インスタンス化
        post.value = pred
        post.created_at = datetime.now(pytz.timezone("Asia/Tokyo"))
        db.session.add(post)
        db.session.commit()
        return redirect("/")
    else:
        # GET(画面表示)の場合はcreate.htmlを表示するだけ
        return render_template("/create.html", col_names=col_names)

@app.route('/<int:id>/delete', methods=["GET"])
def delete(id):
    # 削除するidを取得してDBから削除する
    post = Post.query.get(id)
    db.session.delete(post)
    db.session.commit()
    # 削除後は"/results"にルーティングする
    return redirect("/results")

@app.route("/results")
def results():
    # 結果画面はGETしかないのでDBを更新して表示する
    posts = Post.query.all()
    return render_template("/results.html", posts=posts)

5. コードの解説

それぞれの関数,クラスについて説明します。

  • predict関数
    モデルをロードし、入力された説明変数に対して予測値を出力する関数です。
def predict(parameters):
    model = joblib.load('model/lgb_model.pkl')
    params = parameters.reshape(1,-1)
    pred = model.predict(params)
    return pred
  • Postクラス
    ここではDBの定義を行っています。
    カラムとしてid , value, create_atの3つを定義しています。
class Post(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    value = db.Column(db.Float, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, default=datetime.now(pytz.timezone("Asia/Tokyo")))
  • index関数
    URLが/の場合に実行される関数です。index.htmlファイルを出力するように記載されています。
@app.route("/", methods=["GET", "POST"])
def index():
    # /の場合はindex.htmlを表示させる
    return render_template("index.html")
  • create関数
    入力された値(説明変数)をモデルに入れて予測結果を取得します。
    取得した値は日付時刻と一緒にDBへ追加します。
@app.route("/create", methods=["GET", "POST"])
def create():
    if request.method == "POST":
        # POST(value送信)の場合は推論して"/"にredirectする
        pred_list = []
        for col in col_names:
            # 入力フォームに値が入力されていることを確認
            if request.form.get(col) == "":
                return redirect("/create")
            pred_value = request.form.get(col)
            pred_list.append(pred_value)
        x = np.array(pred_list)
        pred = predict(x)

        # 予測をしたのでDBを更新する
        post = Post()  # インスタンス化
        post.value = pred
        post.created_at = datetime.now(pytz.timezone("Asia/Tokyo"))
        db.session.add(post)
        db.session.commit()
        return redirect("/")
    else:
        # GET(画面表示)の場合はcreate.htmlを表示するだけ
        return render_template("/create.html", col_names=col_names)
  • delete関数
    予測結果画面で過去の予測値を削除したい場合に使用します。
    予測結果画面に作成されている削除ボタンを押すことで以下関数が呼び出されます。
    渡されたidからデータを特定し削除を行い、再度results画面を表示します。
@app.route('/<int:id>/delete', methods=["GET"])
def delete(id):
    # 削除するidを取得してDBから削除する
    post = Post.query.get(id)
    db.session.delete(post)
    db.session.commit()
    # 削除後は"/results"にルーティングする
    return redirect("/results")
  • results関数
    予測結果画面です。
    DBから最新の値を取得してresults.htmlに渡すことで結果を表示しています。
@app.route("/results")
def results():
    # 結果画面はGETしかないのでDBを更新して表示する
    posts = Post.query.all()
    return render_template("/results.html", posts=posts)

6. それ以外のファイル

templatesディレクトリには画面表示するhtmlファイルを作成します。
今回はバックエンドの実装をメインの学習としたため説明は省略します。(というか自分でもあまり理解できていない&フロントの勉強は今後の課題)
以下のgitに全てのコードを上げています。
https://github.com/hfhs1213/Flask_web_app/tree/feature/flask_web_test

環境構築(再現)方法

  1. gitからファイルをcloneする (https://github.com/hfhs1213/Flask_web_app/tree/feature/flask_web_test)
  2. Dockerfileを上位階層に移動する
    root/
     ├─ Dockerfile   ←Backendディレクトリと同じ階層に移動!!
     └─ Backend/
        ├─ data/
        ├─ instance/
        ├─ model/
        ├─ templates/
        ├─ app.py
        ├─ config.py
        ├─ fit.py
        └─ requirements.txt
  3. ビルドの実行
    sudo docker build -t flask_image .
  4. コンテナ起動
    sudo docker run -it --name flask -p 5000:5000 -v /home/fumiki/projects/web_app_ml/Backend:/projects/ flask_image
  5. webサーバーの起動
    flask run
  6. 接続
    Chrome, Edgeなどでlocalhost:5000に接続する

その他のファイルについて

  • instance/params.db
    このディレクトリは以下のコマンドをターミナル上で実行することで、自動作成されます。初回のみの実行でよいので、gitからcloneする際には不要です。
    ただし、開発中にDBのテーブルを更新したり追加する場合は適宜実行する必要があります。
% python3
>>> from app import app
>>> from app import db
>>> with app.app_context():
...    db.create_all()
>>> exit()
  • config.py
    設定値を入れるように作りましたが、今回は説明変数と目的変数を定義するだけに使いました。

つまづいたこと

  • Flaskのデフォルトでは127.0.0.1のループバックアドレスとなっているため、wsl2上で構築する場合は接続できるが、コンテナ上ではこのままでは接続できなかった。設定を0.0.0.0にする必要がある。今回はDockerfile内の環境変数でFLASK_RUN_HOST=0.0.0.0を設定した。
  • pip installを実行すると仮想環境を推奨するエラーがでて進まなかった。オプションで--break-system-packagesを付けるとうまくいった。エラーが出るときと出ない時がある。。
  • return部分のredirectrender_templateの違いが判らず混在してエラーが出ていた。
    現状の理解は、redirectは表示したい(飛ばしたい)リンクを設定するため実際に表示されるものはそのリンク先の設定次第なのでとなるイメージ、render_templateは表示するhtmlを設定するので=となる。

今後やりたいこと

今後は複数コンテナでWeb3層構造を実装することでより本番環境のバックエンド、インフラを勉強したいと思います。
また、DBをPostgresqlとかにも切り替えて動かしてみたいです。

参考リンク

2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?