5
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

PyTorch on Lambda で MNIST の推論を実行する API を作成する(AWS CDK)

Posted at

はじめに

Docker イメージが Lambda で実行できるようになり、イメージも 10GB までとなったため、機械学習の推論 API のロジックを Lambda で実行できるかを試してみました。
今回は PyTorch で作った MNIST のモデルを置いてみます。

環境

  • Python 3.8
  • PyTorch 1.6.0 CPU
  • PyTorch 1.1.8

API Gateway + Lambda で画像を受け取る

まずは、API なので画像を受け取って Lambda で処理できる形まで持っていく必要があります。こちらは長くなったので別記事に記載します。以下の記事の内容が完了していることを前提とします。
画像をAPI Gateway+Lambdaで受け取って Pillow で処理する(AWS CDK) - Qiita

モデル作成

この記事で詳細は解説しませんが、以下のように別途ディレクトリーを作成して PyTorch モデルの重みを保存したファイルを作成します。

$ mkdir models
$ cd models
$ touch train.py

以下のパッケージをインストールします。

  • torch
  • torchvision

学習用の Python ファイルを作成します。

models/train.py
from pytorch_lightning.metrics.functional import accuracy
import pytorch_lightning as pl
from torchvision import transforms, datasets
import torch.nn.functional as F
import torch.nn as nn
from torchvision import datasets
import torch

import torchvision
from torchvision import transforms


transform = transforms.Compose([
    transforms.ToTensor()
])
train_val = datasets.MNIST(
    './', train=True, download=True, transform=transform)
test = datasets.MNIST('./', train=False, download=True, transform=transform)

n_train, n_val = 50000, 10000
train, val = torch.utils.data.random_split(train_val, [n_train, n_val])

batch_size = 1028

train_loader = torch.utils.data.DataLoader(
    train, batch_size, shuffle=True, drop_last=True)
val_loader = torch.utils.data.DataLoader(val, batch_size)
test_loader = torch.utils.data.DataLoader(test, batch_size)


class Net(pl.LightningModule):

    def __init__(self):
        super().__init__()

        self.conv = nn.Conv2d(in_channels=1, out_channels=3,
                              kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(3)
        self.fc = nn.Linear(588, 10)

    def forward(self, x):
        h = self.conv(x)
        h = F.relu(h)
        h = self.bn(h)
        h = F.max_pool2d(h, kernel_size=2, stride=2)
        h = h.view(-1, 588)
        h = self.fc(h)
        return h

    def training_step(self, batch, batch_idx):
        x, t = batch
        y = self(x)
        loss = F.cross_entropy(y, t)
        self.log('train_loss', loss, on_step=True,
                 on_epoch=True, prog_bar=True)
        self.log('train_acc', accuracy(y, t), on_step=True,
                 on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, t = batch
        y = self(x)
        loss = F.cross_entropy(y, t)
        self.log('val_loss', loss, on_step=False, on_epoch=True)
        self.log('val_acc', accuracy(y, t), on_step=False, on_epoch=True)
        return loss

    def test_step(self, batch, batch_idx):
        x, t = batch
        y = self(x)
        loss = F.cross_entropy(y, t)
        self.log('test_loss', loss, on_step=False, on_epoch=True)
        self.log('test_acc', accuracy(y, t), on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.01)
        return optimizer


net = Net()
trainer = pl.Trainer(max_epochs=5, gpus=0, deterministic=True)
trainer.fit(net, train_loader, val_loader)
results = trainer.test(test_dataloaders=test_loader)
torch.save(net.state_dict(), 'mnist.pt')

作成した学習用 Python ファイルを実行します。

$ python train.py
$ ls
MNIST           lightning_logs  mnist.pt       train.py

学習の実行が成功すると、mnist.pt ファイルが生成されます。

Lambda 構築

別記事で解説した内容とほぼ同じため差分を解説します。

Dockerfile 変更

必要なパッケージをインストールする必要があるため、Dockerfile を変更します。

src/Dockerfile
FROM public.ecr.aws/lambda/python:3.8
RUN pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html && pip install pillow && pip install pytorch-lightning
COPY mnist.pt ./
COPY app.py ./
CMD [ "app.handler" ]

モデル配置

学習して作成したモデルファイルを src ディレクトリーに配置します。

# プロジェクトの Root ディレクトリーに戻る
$ cd ..
$ cp models/mnist.pt src/mnist.pt

Lambda コード変更

app.py を以下のように変更します。
Pillow で読み込んだファイルを PyTorch のモデル + 事前に学習した重みで推論して結果を返すようにします。

src/app.py
import base64
from io import BytesIO
import torch
from torchvision import transforms
import pytorch_lightning as pl
import torch.nn as nn
import torch.nn.functional as F
import json
from PIL import Image


class Net(pl.LightningModule):

    def __init__(self):
        super().__init__()

        self.conv = nn.Conv2d(in_channels=1, out_channels=3,
                              kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(3)
        self.fc = nn.Linear(588, 10)

    def forward(self, x):
        h = self.conv(x)
        h = F.relu(h)
        h = self.bn(h)
        h = F.max_pool2d(h, kernel_size=2, stride=2)
        h = h.view(-1, 588)
        h = self.fc(h)
        return h


def handler(event, context):
    data = event.get('body', '')
    data = BytesIO(base64.b64decode(data))
    image = Image.open(data)

    net = Net().cpu().eval()
    net.load_state_dict(torch.load(
        'mnist.pt', map_location=torch.device('cpu')))

    transform = transforms.Compose([
        transforms.ToTensor()
    ])

    x = transform(image)

    y = net(x.unsqueeze(0))
    y = F.softmax(y)
    y = torch.argmax(y)

    return {
        'statusCode': 200,
        'body': json.dumps({
            'number': '{}'.format(y.item()),
        }),
    }

デプロイ

デプロイは cdk を使って行います。

# プロジェクトの Root ディレクトリーに戻ります
$ cd ..
$ cdk deploy

実行

デプロイに成功したら、Insomnia から数字の書かれた画像ファイルを送信してみます。

最初はこちらの 1 の画像を送信してみます。
39.png

初回起動は時間がかかりますが、2 回目以降は予想以上の速度でレスポンスが返ってきます。
Screenshot 2021-02-15 at 13.44.14.png

  • 5

152.png

Screenshot 2021-02-15 at 13.46.13.png

  • 8

401.png

Screenshot 2021-02-15 at 13.48.41.png

おわりに

PyTorch + API Gateway + Lambda による推論 API の構築方法を解説しました。

mnist であれば十分動作する API になったのではないかと思います。
mnist.pt は 26KB と小さかったため、もう少し大きなモデルを構築した場合にどのようになるのかは別途試してみたいと思います。

5
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?