2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

PyTorch + Flask + uWSGIによるシンプル機械学習APIの開発 (モデル学習編)

Last updated at Posted at 2020-12-10

はじめに

機械学習モデルを作るところまでのチュートリアルは世の中にあふれていますが、作ったモデルを実際どう使うの?という部分になると途端に情報が少なくなりますよね。

ここでは、自身の備忘の意味も込めて、「手書き文字認識」タスクをサンプルに、モデルの構築からデプロイまでの手順の一例を一気通貫に紹介していきます。

シンプルさを重視しているため実際のサービスに乗せるには心もとない点も多いですが、全体的なイメージをつかむきっかけになればうれしいです。

長くなったため前後編に分けており、この記事は1の部分を対象にした前編です。

  1. PyTorchによるモデルの学習 (モデル学習編)
  2. Flask + uWSGIによるAPIの実装 (アプリ実装編)

なお、本記事におけるモデル学習のコードはPyTorchの公式チュートリアルをベースにしています。

間違いやより良い方法など、お気づきの点があればぜひお気軽にコメントください。

環境構築

実行環境はWSL上のUbuntu18.04です。
また、本記事内のコードはすべてJupyter Notebookで実行しています。

まずは、今回のプロジェクト用に新しくディレクトリを作成します。

!mkdir Digit_Recognizer

続いて、仮想環境を準備します。今回はPython公式も推奨しているPipenv (公式ドキュメント)を使って仮想環境を構築します。
プロジェクトごとに環境を分けることにより、インストールされているパッケージが明確になり、予期せぬ不具合を防いだり、全く同じ環境を再現することが簡単になるというメリットがあります。

# pipenvのインストール
!pip install pipenv
# python3で仮想環境を初期化
%cd Digit_Recognizer/
!pipenv --python 3  
# パッケージのインストール (開発時にのみ利用するものはdevを指定し別管理)
!pipenv install numpy torch torchvision Pillow
!pipenv install --dev matplotlib

プロジェクトのディレクトリに移動後、pipenvコマンドを実行すると、自動で仮想環境が作成され、Pipfileというファイルが作られます。
Pipfileには、pipenvを使ってインストールしたパッケージの名称が追加されていきます。
同時にPipfile.lockが自動で生成され、実際にインストールされたパッケージの詳細なバージョンや依存パッケージの情報などが記録されます。

同じ環境を作成したい場合は、pipenv installで一括インストールできるため簡単に再現可能です。

データの準備

続いて、モデル学習用のデータを準備します。

今回は0~9の正解ラベルつき手書き数字データセットとして有名なMNISTを利用し、手書き文字認識タスク用のモデルを構築します。
ファイルのみダウンロードすることもできますが、torchvisionを利用することで簡単にPyTorchで利用可能な形式に変換できるため、そちらの方法を採用します。

DatasetやDataloaderって何?という方はこちらの記事をご覧ください。

# MNISTデータセットのダウンロード
import torch
import torchvision
import torchvision.transforms as transforms


transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])  # 正規化

trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                      download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=8,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.MNIST(root='./data', train=False,
                                     download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=8,
                                         shuffle=False, num_workers=2)

学習用のDataloaderからミニバッチ1つ分のデータをサンプリングし、画像とラベルを確認します。

import matplotlib.pyplot as plt
import numpy as np


def imshow(img):
    img = img * 0.5 + 0.5  # 表示用に正規化処理をリセット
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.axis('off')
    plt.show()


# 学習用データからサンプリング
dataiter = iter(trainloader)
images, labels = dataiter.next()

# 画像と正解ラベルの表示
imshow(torchvision.utils.make_grid(images))
print(' '.join(f'{labels[j].numpy(): 5}' for j in range(8)))

image.png

出力
0    5    4    9    2    1    1    4

モデルの学習

今回はCNNを利用することにし、ネットワークを以下のように定義します。
入出力が合っていればここは何でも構いません。

import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)

        return x


net = Net()

続いて、損失関数とオプティマイザを定義します。

import torch.optim as optim


criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

材料がすべて揃ったので、モデルを学習します。

for epoch in range(2):

    running_loss = 0.0
    for i, data in enumerate(trainloader):
        # インプットデータの取得 (dataは[inputs, labels]のリスト)
        inputs, labels = data

        # 勾配をゼロで初期化
        optimizer.zero_grad()

        # 順伝播 + 誤差逆伝播 + 重み更新
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # 途中経過の表示
        running_loss += loss.item()
        if i % 2000 == 1999:    # 2,000ミニバッチごとに表示
            print(f'epoch: {epoch + 1} - iteration: {i + 1} - loss: {running_loss / 2000: .3f}')
            running_loss = 0.0
出力
epoch: 1 - iteration: 2000 - loss: 0.177
epoch: 1 - iteration: 4000 - loss: 0.152
epoch: 1 - iteration: 6000 - loss: 0.126
epoch: 2 - iteration: 2000 - loss: 0.100
epoch: 2 - iteration: 4000 - loss: 0.090
epoch: 2 - iteration: 6000 - loss: 0.089

学習したモデルの精度をテスト用データで確認します。

net.eval()
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct / total}%')
出力
Accuracy of the network on the 10000 test images: 97.66%

うまく学習できているようです。
学習したモデルのパラメータはプロジェクトディレクトリ下に保存しておきます。

PATH = './mnist_net.pth'
torch.save(net.state_dict(), PATH)

最後に、APIの動作確認用としてサンプルデータをいくつか保存し、ダウンロードしたMNISTのデータはフォルダごと削除しておきます。

!mkdir sample_images
!rm -r data
from PIL import Image


dataiter = iter(testloader)
iamges, _ = dataiter.next()
for i, image in enumerate(images):
    image = (image * 0.5 + 0.5) * 255
    image = image.numpy().astype('uint8').squeeze(0)
    Image.fromarray(image).save(f'sample_images/sample{i}.jpg')

モデル学習編は以上です。
次回は学習したモデルを使って、受け取った手書き文字画像に書かれた数字を予測し、結果を返すAPIを実装していきます。

アプリ実装編へ

2
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?