Qiita Teams that are logged in
You are not logged in to any team

Log in to Qiita Team
Community
OrganizationAdvent CalendarQiitadon (β)
Service
Qiita JobsQiita ZineQiita Blog
Help us understand the problem. What is going on with this article?

[Flask入門] 簡単な機械学習APIの作成

More than 1 year has passed since last update.

本記事で行うこと

  • wikipediaの文章を入力データ、その文章に対応するカテゴリーを正解データとして、文章を2値分類(または学習)するAPIを作成する

対象読者

  • Flaskを使って何か作ってみたい方
  • 機械学習を用いたアプリケーションを作ってみたい方

使用言語

  • Python 3.6.3

アプリケーションのディレクトリ構成

今回使用するファイル及びディレクトリの構成は以下のようになっています。
ファイルの内容は順を追って説明します。

tree
flask
├── config
│   └── requirements.txt
├── data
│   ├── wikipedia-train.txt
│   └── wikipedia-test.txt
├── model
│   └── wikipedia-category_logistic.pkl.gz
├── lib
│   ├── preprocess.py
│   └── trainer.py
├── app.py
└── wikipedia_vectorizer.pkl.gz

開発環境の構築

必要なライブラリーをまとめてインストールする。

bash
# versionの確認
$pyenv versions
  system
* 3.6.3 (set by /Users/../flask/.python-version)
# venv
$. env/bin/activate
(env) $ pip install -r config/requirements.txt

config/requirements.txt :今回使うライブラリーをまとめたファイル

astroid==2.0.4
Click==7.0
flake8==3.6.0
Flask==1.0.2
isort==4.3.4
itsdangerous==1.1.0
Jinja2==2.10
joblib==0.12.5
lazy-object-proxy==1.3.1
MarkupSafe==1.0
mccabe==0.6.1
mecab-python3==0.7
numpy==1.14.5
pandas==0.22.0
pycodestyle==2.4.0
pyflakes==2.0.0
pylint==2.1.1
python-dateutil==2.7.4
pytz==2018.6
scikit-learn==0.20.0
scipy==1.1.0
six==1.11.0
sklearn==0.0
typed-ast==1.1.0
Werkzeug==0.14.1
wrapt==1.10.11

データの前処理~学習を行うファイル

data/wikipedia-train.txt : 学習時に使用する入力データと正解データ
data/wikipedia-test.txt: 学習したモデルの検証用データ

以下、学習用/検証用データはwikipediaをスクレイピングし、categoryとtextデータを取得する。

wikipedia-train.txt
category    text
sports  ナイキ(Nike, Inc.)は、アメリカ合衆国・オレゴン州に本社を置くスニーカーやスポーツウェアなどスポーツ関連商品を扱う世界的企業
sports  コンバース(Converse)は、アメリカのシューズ製造販売会社。オールスター、ジャックパーセルなどのスニーカーなどで知られる。
fashion ネイルケア(nail care)とは、ヒトの爪とその周辺の手入れのことを言う。一般的な爪を切る行為から、美容や身だしなみ、さらに医療行為まで、様々な目的があり、ネイルケアは一つの確立された分野となっている。
...

lib/preprocess.py : 上記、入力データの文章に対してわかち書きを行うメソッド、tf-idfによって単語の分散表現を取得するメソッドを内包した前処理を行うクラス

lib/preprocess.py
# pylint: disable=missing-docstring

import joblib
import MeCab
from sklearn.feature_extraction.text import TfidfVectorizer

MODEL_NAME = "wikipedia_vectorizer.pkl.gz"

class PreProcess:
    def __init__(self):
        self.mecab = MeCab.Tagger("-O wakati")
        self.vectorizer = TfidfVectorizer(token_pattern=r"(?u)\b\w+\b")

    def get_tokenized(self, df):
        """
        わかち書きを行う
        """
        text_tokenized = []
        for text in df["text"]:
            text_tokenized.append(self.mecab.parse(text))
        df["text_tokenized"] = text_tokenized
        return df

    def get_tfidf(self, train, test):
        """
        tf-idfを用いて単語の分散表現を取得する
        """
        train_x = self.vectorizer.fit_transform(train["text_tokenized"])
        test_x = self.vectorizer.transform(test["text_tokenized"])

        joblib.dump(self.vectorizer, MODEL_NAME)

        return train_x, test_x

lib/trainer.py : 実際に学習を行いモデルをmodelディレクトリに格納し、検証用データを使って正解率を返すクラス

今回ロジスティック回帰を用いているが、事前にランダムフォレストやSVM、ニューラルネットワークを使って正解率を求めた。その結果、最もロジスティック回帰が正解率が高かったためこの手法を用いた。

lib/trainer.py
# pylint: disable=missing-docstring
# pylint: disable=too-few-public-methods

import os
import numpy as np
import pandas as pd
import joblib

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import accuracy_score
from lib.preprocess import PreProcess


BASE_DIR = os.path.dirname(os.path.abspath(__file__))

TRAIN_DATA_PATH = os.path.join(BASE_DIR, "../data/wikipedia-train.txt")
TEST_DATA_PATH = os.path.join(BASE_DIR, "../data/wikipedia-test.txt")
MODEL_PATH = os.path.join(BASE_DIR, "../model")
MODEL_NAME = "wikipedia-category_logistic.pkl.gz"


class Trainer:

    def __init__(self):
        self.raw_train = pd.read_csv(TRAIN_DATA_PATH, sep="\t")
        self.raw_test = pd.read_csv(TEST_DATA_PATH, sep="\t")
        self.preprocess = PreProcess()

    @staticmethod
    def grid_search(train_x, train_t):
        params = np.arange(1, 100, 1)
        print("start grid search!")
        clf_lr = GridSearchCV(LogisticRegression(random_state=0, verbose=0),
                              param_grid={"C": params})

        clf_lr.fit(train_x.toarray(), train_t)
        clf_lr_bst = clf_lr.best_estimator_
        print("best_score:{}".format(clf_lr.best_score_))
        return clf_lr_bst

    def get_training_data(self, raw_train, raw_test):
        df_train = self.preprocess.get_tokenized(raw_train)
        df_test = self.preprocess.get_tokenized(raw_test)

        train_x, test_x = self.preprocess.get_tfidf(df_train,
                                                    df_test)
        train_t = self.raw_train["category"]
        test_t = self.raw_test["category"]

        return train_x, test_x, train_t, test_t

    def train(self):
        train_x, test_x, \
            train_t, test_t = self.get_training_data(self.raw_train, self.raw_test)

        clf_lr_bst = self.grid_search(train_x, train_t)

        print("start training!")
        clf_lr_bst.fit(train_x, train_t)
        y_test_pred = clf_lr_bst.predict(test_x.toarray())

        accuracy = accuracy_score(test_t, y_test_pred)

        print("accuracy:{}".format(accuracy))

        joblib.dump(clf_lr_bst, f"{MODEL_PATH}/{MODEL_NAME}")
        print("saved model")

        return accuracy

model/wikipedia-category_logistic.pkl.gz : 学習した際に生成されたモデル
wikipedia_vectorizer.pkl.gz : tf-idfによって生成された単語の分散表現

Flaskを使ったAPI実装用のファイル

app.py : 入力データに対して分類結果を返す、または学習を実行するAPI実装ファイル

エンドポイントには、/classify/trainを用意しています。

  • /trainにアクセスすると、lib/trainer.pyTrainer().train()メソッドを実行し学習を行い、正解率を返す。

  • /classifyにアクセスすると、学習時に生成したモデルを読み込み入力データに対する分類結果を返す。(入力データの与え方は次項に記述する)

app.py
import os
import subprocess
import traceback

import joblib
import MeCab

from flask import Flask, request, abort, jsonify
from lib.trainer import Trainer

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_NAME = "wikipedia-category_logistic.pkl.gz"
MODEL_PATH = os.path.join(BASE_DIR, f"model/{MODEL_NAME}")

app = Flask(__name__)

vectorizer = joblib.load("wikipedia_vectorizer.pkl.gz")


@app.route("/train")
def train():
    accuracy = Trainer().train()

    output_json = jsonify({
        "code": 200,
        "message": None,
        "result": accuracy
    })
    return output_json


@app.route("/predict", methods=["POST"])
def process_request():

    print("モデルを読み込みます...")
    if os.path.exists(MODEL_PATH):
        clf = joblib.load(MODEL_PATH)
    else:
        return "model not trained. call `/train` endpoint"

    try:
        mecab_dicdir = subprocess.run(
            "mecab-config --dicdir",
            shell=True,
            stdout=subprocess.PIPE,
            universal_newlines=True
        ).stdout.rstrip()
        mecab = MeCab.Tagger("-b 5242880 -Owakati --dicdir={}".format(mecab_dicdir + "/mecab-ipadic-neologd"))
        print("mecab-ipadic-neologd を使用します")
    except:
        mecab = MeCab.Tagger("-b 5242880 -Owakati")
        print("デフォルトの辞書を使用します")

    if not request.is_json:
        abort(400, {"message": "Input Content-Type is not application/json."})

    data = request.get_json()
    if "text" not in data:
        abort(400, {"message": "text is not present in request parameter."})

    text = data["text"]
    if not isinstance(text, str):
        abort(400, {"message": "text is not string."})

    try:
        text_tokenized = mecab.parse(text)
        target_data = vectorizer.transform([text_tokenized])
        cls = clf.predict(target_data)[0]
    except Exception as e:
        abort(500, {"message": "prediction error occurred: {}".format(e)})

    output_json = jsonify({
        "code": 200,
        "message": None,
        "result": cls,
    })

    return output_json


@app.errorhandler(400)
def bad_request_handler(error):
    output_json = jsonify({
        "code": error.code,
        "message": error.description["message"],
    })
    return output_json, error.code


@app.errorhandler(404)
def not_found_handler(error):
    output_json = jsonify({
        "code": error.code,
        "message": "Requested resource is not found.",
    })
    return output_json, error.code


@app.errorhandler(Exception)
def internal_server_error_handler(e):
    print(traceback.format_exc())

    output_json = jsonify({
        "code": 500,
        "message": traceback.format_exc(),
    })
    return output_json, 500


if __name__ == '__main__':
    app.run(host="localhost")

APIサーバーの起動

以下のコマンドでAPIサーバーを起動する。

bash
(env) $ python app.py 
 * Serving Flask app "app" (lazy loading)
 * Environment: production
   WARNING: Do not use the development server in a production environment.
   Use a production WSGI server instead.
 * Debug mode: off
 * Running on http://localhost:5000/ (Press CTRL+C to quit)

別のターミナルを開き、curlコマンドを用いて各種エンドポイントにアクセスする。

bash
# 学習を行い/model以下にモデルを生成し、正解率を返す
(env) $ curl http://localhost:5000/train
{"code":200,"message":null,"result":0.93}

# 入力データに対して分類結果を返す
(env) $ curl -X POST -H "Content-Type: application/json" --data '{ "text": "
テキスト" }'  http://localhost:5000/predict
{"code":200,"message":null,"result":"sports"}

--data '{ "text": "テキスト" }'"テキスト"箇所に文章を入力する。
上記のように、ステータスコードと分類結果が返ってくることを確認できれば終わりです。

時間があればそのうちAWS API Gatewayを使ったサーバーレスなAPIを作ってみようと思います。

kurakura0916
GCP, AWSを使った機械学習のインフラについて書いています。
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away