34
Help us understand the problem. What are the problem?

More than 1 year has passed since last update.

posted at

Python × Flask × PyTorch 数字認識Webアプリのお手軽構築

Python、Flask、PyTorchを使って画像認識アプリを作ってみます。
この3つを組み合わせれば、お手軽かつ爆速でデモアプリを作れます。

前置き

Flaskとは

Python用のWebフレームワークです。
PythonのWebフレームワークはDjangoが有名ですが、Flaskは軽量を売りにしています。Djangoに比べると機能や拡張ライブラリは少ないですが、制約がある分コードもシンプルになり、お手軽にアプリケーションを作成することができます。
環境構築も簡単にできるのでWebアプリケーションのプロトタイプを作るのに向いています。

Flaskと画像処理、機械学習との親和性の良さ

Pythonは機械学習関連のライブラリが充実しており、デファクトスタンダードになっているのは周知のことです。またPythonはOpenCVやPillow(PIL)などの画像処理ライブラリも充実しており、ネット上の情報も豊富です。
こういった背景もありPython × Flaskでやると機械学習ライブラリ、画像処理ライブラリの利用がとてもやりやすく簡単にアプリケーションを作れます。

今回のゴール

ブラウザから手書き数字画像をアップロードすると、数字を認識して結果を表示するアプリケーションを作ってみます。
2019-12-13_20h51_28.png

アプリケーションの構成

機械学習モデルの構築

今回はPyTorchを使ってMNISTの手書き数字認識モデルを作りました。

Google ColaboratoryでPyTorchでMNISTを学習したモデルを保存し、それを読み出して使う簡単サンプル - 人工知能プログラミングやってくブログ
この記事を参考に学習モデルを作ります。
動かすと1,725,616バイトのmnist_cnn.ptができました。

環境の構築

Flask環境の構築はpipがインストールしてあれば pip install Flask で一発です。
今回はPillow(PIL)、PyTorchも使っているので、これもインストールしておきます。

Webアプリの構築

ディレクトリ、ファイル構成は以下のようになります。

├── mnist_cnn.pt … 手書き数字認識モデル
├── predict.py … メインのスクリプト。ファイルのアップロードと画像判定を行う
├── static … アップロードしたファイルの配置先
│   ├── 20191213210438.png … ここにアップロードしたファイルが保存される
│   ├── 20191213210253.png
│   └── 20191213210341.png
├── templates … htmlテンプレートの保存先
     ├── index.html

predict.pyの内容です。
機械学習モデルの定義・ロードとWebアプリの処理を記述しています。
モデル定義、画像の前処理についての詳細を知りたい方は以下の記事を参照してください。
Pytorch×MNIST手書き数字認識 PNG画像を入力に予測してみる - Qiita

predict.py
# 必要なモジュールを読み込む
# Flask関連
from flask import Flask, render_template, request, redirect, url_for, abort

# PyTorch関連
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms

# Pillow(PIL)、datetime
from PIL import Image, ImageOps
from datetime import datetime

# モデルの定義
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


device = torch.device("cpu")
model = 0
model = Net().to(device)
# 学習モデルをロードする
model.load_state_dict(
    torch.load("./mnist_cnn.pt", map_location=lambda storage, loc: storage)
)
model = model.eval()

app = Flask(__name__)


@app.route("/", methods=["GET", "POST"])
def upload_file():
    if request.method == "GET":
        return render_template("index.html")
    if request.method == "POST":
        # アプロードされたファイルをいったん保存する
        f = request.files["file"]
        filepath = "./static/" + datetime.now().strftime("%Y%m%d%H%M%S") + ".png"
        f.save(filepath)
        # 画像ファイルを読み込む
        image = Image.open(filepath)
        # PyTorchで扱えるように変換(リサイズ、白黒反転、正規化、次元追加)
        image = ImageOps.invert(image.convert("L")).resize((28, 28))
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )
        image = transform(image).unsqueeze(0)
        # 予測を実施
        output = model(image)
        _, prediction = torch.max(output, 1)
        result = prediction[0].item()

        return render_template("index.html", filepath=filepath, result=result)


if __name__ == "__main__":
    app.run(debug=True)

index.htmlの内容です。
このHTMLテンプレートの中にファイルアップロードと認識結果表示を記述してあります。

index.html
<html>
    <body>
        {% if result %}
          <IMG SRC="{{filepath}} " BORDER="1"> 認識結果は {{result}} です<BR>
          <HR>
        {% endif %}
        ファイルを選択して送信してください<BR>
        <form action = "./" method = "POST" 
           enctype = "multipart/form-data">
           <input type = "file" name = "file" />
           <input type = "submit"/>
        </form>
     </body>
</html>

起動と動作確認

python predicy.pyを実行するとFlaskのWebサーバが起動してアプリが動き始めます。ちなみにFlaskのデフォルトポートは5000番です。
http://localhostかホスト名:5000/でアクセスするとWebアプリケーションが表示されます。
手書き数字をアップロードすると
2019-12-13_21h08_00.png
認識結果を表示してくれます。ちゃんと「9」と認識できています。
2019-12-13_21h08_15.png

まとめ

機械学習を使って画像認識するWebアプリケーションを作る、と聞くと難しそうでコードも多く複雑になりそうですが、Flaskを使えば本当にお手軽にできます。
機械学習モデルは作って終わりではなく、まずいろんな人に使ってもらってなんぼかと思います。
ただコマンドベースだと、非エンジニアの人は使えないし、いろいろ試してみるのも難しいです。
そういう時にFlaskを使ってサクッとプロトタイプを作ってしまうのがおすすめです。

Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Sign upLogin
34
Help us understand the problem. What are the problem?