1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

PyTorch + Flask + uWSGIによるシンプル機械学習APIの開発 (アプリ実装編)

Last updated at Posted at 2020-12-10

はじめに

機械学習モデルを作るところまでのチュートリアルは世の中にあふれていますが、作ったモデルを実際どう使うの?という部分の情報になると途端に少なくなりますよね。

ここでは、自身の備忘の意味も込めて、「手書き文字認識」タスクをサンプルに、モデルの構築からデプロイまでの手順の一例を一気通貫に紹介していきます。

実際のサービスに乗せるには心もとない点も多いですが、全体的なイメージをつかむきっかけになればうれしいです。

長くなったため前後編に分けており、この記事は2の部分を対象にした後編です。

  1. PyTorchによるモデルの学習 (モデル学習編)
  2. Flask + uWSGIによるAPIの実装 (アプリ実装編)

なお、本記事におけるFlask関連のコードはPyTorchの公式チュートリアルをベースにしています。

間違いやより良い方法など、お気づきの点があればぜひお気軽にコメントください。

環境構築

前回の記事の環境を引き続き利用します。

記事のコードをすべて実行した段階では、プロジェクトのディレクトリは以下の内容になっているかと思います。

Digit_Recognizer    
    ├── mnist_net.pth   
    ├── Pipfile  
    ├── Pipfile.lock  
    └── sample_images

上記をスタートとして、本記事では最終的に以下の構成となります。

Digit_Recognizer  
    ├── app  
    │    ├── main.py  # メインスクリプト  
    │    └── net.py  # ネットワーク定義スクリプト   
    ├── mnist_net.pth  
    ├── uwsgi.ini  # uWSGIの設定ファイル  
    ├── Pipfile  
    ├── Pipfile.lock  
    ├── run.py  # 開発時のアプリ実行用スクリプト  
    └── sample_images

コマンドラインからアプリ用のディレクトリを作成し、必要なパッケージをインストールしておきます。

cd ./Digit_Recognizer/
mkdir app
pipenv install flask uwsgi

アプリの実装

さっそくアプリケーションを実装していきます。

今回目指すのは、手書き数字の画像を受け取り、そこに書いてある数字を予測した上で、その結果を返すAPIです。

なお、エンドユーザ向けではなく、システム内での利用 (あるシステムからAPIサーバにリクエストを投げ、返ってきた結果を後続の処理に利用するような構成) を想定し、アップロード画面のようなインターフェースは作らず、コマンドラインからの実行を前提とします。

サンプルアプリの実装

まずはじめにサンプルスクリプトでFlaskの挙動を確認しておきます。
次のflask_sample.pyを作成し、適当なディレクトリに保存します。

flask_sample.py
from flask import Flask

# 1)Flaskインスタンスの作成
app = Flask(__name__)

# 2)相対URL'/'にアクセスがあったときの処理を定義
@app.route('/')
def hello():
    return "Hello World!"

if __name__ == '__main__':
    # 3)アプリの起動
    app.run()

基本要素は、1)Flaskインスタンス (アプリの本体) の作成、2)URLごとの処理の定義、3)アプリの起動の3つです。
後ほど2でモデルによる予測を実装することになります。

では、作成したファイルをコマンドラインから実行してみます。

python flask_sample.py
出力
 * Serving Flask app "flask_sample" (lazy loading)
 * Environment: production
   WARNING: This is a development server. Do not use it in a production deployment.
   Use a production WSGI server instead.
 * Debug mode: off
 * Running on http://127.0.0.1:5000/ (Press CTRL+C to quit)

上記のログが表示されていれば、アプリが正しく起動しています。
デフォルトでは、hostはlocalhost (127.0.0.1)、portは5000が指定されているため、そこを起点とした相対URL/ (ルート)、つまりhttp://localhost:5000/にアクセスすると、指定した処理の結果が返ってくるはずです。
コマンドラインからはcurlコマンドで結果を取得できます。

curl http://localhost:5000/
出力
Hello World!

手書き文字認識アプリの実装

Flaskの雰囲気をつかめたところで、ここからは実際に利用するプログラムを書いていきます。
まずは、ネットワークを定義します。これは学習時に利用したclassをそのままコピペすればOKです。

app/net.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)

        return x

次に、メインのスクリプトを作ります。
基本の構成はflask_sample.pyとほとんど同じです。

主な変更点は、2にモデルによる予測処理を記載し、相対URLを機能に対応した/predictに指定した点です。

また、メソッドをPOSTに指定することで、リスクエスト時にファイルを受け取ることができます。
Flaskにはファイルを受け取るためのrequestオブジェクトが用意されており、POSTされたファイルをrequest.filesで取得することができます。

なお、3のアプリの起動は別のスクリプトから行うため、ここでは省略しています。

GETとPOSTの違いについて

app/main.py
import io

import torch
from torchvision import models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request

from app.net import Net


# 1)Flaskインスタンスの作成
app = Flask(__name__)

# 学習済みモデルの準備 (2で利用)
model = Net()
model.load_state_dict(torch.load('./mnist_net.pth'))
model.eval()


# 前処理用関数の定義 (2で利用)
def transform_image(image_bytes):
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5,), (0.5,))])  # 学習時と同様に正規化
    image = Image.open(io.BytesIO(image_bytes))

    return transform(image).unsqueeze(0)


# 予測用関数の定義 (2で利用)
def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)  # スコアが最大の数字を予測値として取得

    return y_hat.item()


# 2)画像受信時の処理の定義
@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        file = request.files['file']
        img_bytes = file.read()
        prediction = get_prediction(image_bytes=img_bytes)  # 数字の予測

        return jsonify({'prediction': prediction})  # JSON形式に変換

ここで一旦動作を確認します。
先ほど省略した3のアプリ起動用にrun.pyを作成し、実行します。
ちなみに、hostとportは明示的に指定することも可能で、以下ではportに8080を指定しています。

run.py
from app.main import app

app.run(host='localhost', port=8080, debug=True)
python run.py

ここではPythonのRequestsモジュールを使ってリクエストを投げてみます。
レスポンスにjsonメソッドを適用することで、JSON形式の情報をPythonの辞書型として取得することができます。

import requests

resp = requests.post("http://localhost:8080/predict",
                     files={"file": open('sample_images/sample1.jpg','rb')})
print(resp.json())
出力
{'prediction': 2}

目的の機能を実現できていることが確認できました。

uWSGIの導入

ここで、今まで無視してきたアプリ起動時のwarningに注目してみます。

Environment: production
   WARNING: This is a development server. Do not use it in a production deployment.
   Use a production WSGI server instead.

これまでは (黙示的に) Flaskに組み込まれている開発サーバを使用してきましたが、本番環境へのデプロイでは使うなとのことです。

公式ドキュメントにも、

When running publicly rather than in development, you should not use the built-in development server (flask run). The development server is provided by Werkzeug for convenience, but is not designed to be particularly efficient, stable, or secure.

と記載があり、効率性、安定性、セキュリティを意識した設計となっていないことが強調されています。

そこで、代わりに利用するWSGIサーバとしてuWSGIが登場します。
とはいえ、アプリ自体は完成しているため、残りの作業はuWSGIの設定ファイルを作成するだけです。

uwsgi.ini
[uwsgi]
module = app.main
callable = app
http = 0.0.0.0:8080
master = true

次のコマンドを実行すると、uwsgi.iniを参照してuWSGIでアプリが起動されます。

uwsgi --ini uwsgi.ini

リクエストの投げ方は先ほどと同様です。

import requests

resp = requests.post("http://localhost:8080/predict",
                     files={"file": open('sample_images/sample1.jpg','rb')})
print(resp.json())
出力
{'prediction': 2}

想定通りの結果が返ってきたので、以上でアプリは完成です。

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?