概要
PyTorchのモデルをFlaskでデプロイします。画像をBase64エンコードで送信する方式を取ります。
imagenet_class_index.json（クラス番号とラベル名の対応表）はこちらからダウンロードしました。
sample.jpg は分類を試すための画像で、任意の画像で構いません。
コード
フォルダ構造
.
├── Dockerfile
├── client.py
├── docker-compose.yml
├── imagenet_class_index.json
├── main.py
├── model.py
├── requirements.txt
└── sample.jpg
requirements.txt
numpy
pillow
flask
flask_cors
requests
torch
torchvision
gunicorn
サーバーのコード
main.py
import base64
import json
from io import BytesIO
from pathlib import Path

from flask import Flask, jsonify, request
from flask_cors import CORS
from PIL import Image

from model import Predictor
app = Flask(__name__)
# Enable CORS on all routes (flask_cors defaults to allowing any origin).
CORS(app)
# Resolve the label file next to this module instead of relative to the
# current working directory, so the server also starts when launched from
# another directory. The model is built once at import time.
predictor = Predictor(Path(__file__).resolve().parent / 'imagenet_class_index.json')
@app.route("/", methods=["POST"])
def predict():
    """Classify a Base64-encoded image posted as JSON.

    Expects a body of the form ``{"image": "<base64 string>"}`` and responds
    with ``{"result": <label>}``, where the label is the one returned by
    ``predictor.predict``. A malformed request (body that is not a JSON object,
    missing ``image`` field, undecodable Base64 or image data) gets HTTP 400
    with an ``{"error": ...}`` body instead of an unhandled 500.
    """
    # force=True parses the body regardless of Content-Type, matching the
    # previous request.data + json.loads behavior; silent=True yields None
    # instead of raising on invalid JSON.
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict) or 'image' not in data:
        return jsonify({'error': "expected a JSON object with an 'image' field"}), 400
    try:
        img_stream = base64.b64decode(data['image'])
        img_pil = Image.open(BytesIO(img_stream))
        # Image.open only reads the header; decode the pixel data now so a
        # truncated/corrupt image is reported here as a client error.
        img_pil.load()
    except (TypeError, ValueError, OSError, Image.DecompressionBombError):
        # binascii.Error is a ValueError; PIL.UnidentifiedImageError is an OSError.
        return jsonify({'error': 'invalid image data'}), 400
    _, label = predictor.predict(img_pil)
    return jsonify({'result': label})
if __name__ == "__main__":
    import os

    # Development server only (the container runs gunicorn). Flask's debug
    # mode enables the interactive Werkzeug debugger, which permits arbitrary
    # code execution, so it must not be on by default while binding 0.0.0.0.
    # Opt in explicitly with FLASK_DEBUG=1 for local work.
    debug = os.environ.get("FLASK_DEBUG", "").lower() in ("1", "true", "yes")
    app.run(host='0.0.0.0', port=5000, debug=debug)
model.py
import json
import shutil
import numpy as np
import requests
import torch
import torch.nn.functional as F
from torchvision import models, transforms
def get_weight(url='https://s3xxxxxxxxxxxxxx', path='./weight.pth', timeout=60):
    """Download a weight file from ``url`` and write it to ``path``.

    Parameters
    ----------
    url : str
        Location of the weight file (defaults to the original hard-coded URL).
    path : str
        Destination file; overwritten if it exists.
    timeout : float
        Seconds to wait for the connection / each socket read (not a limit
        on the total download time).

    Raises
    ------
    RuntimeError
        On any failure (network error, HTTP error status, write error); the
        original exception is chained as ``__cause__``.
    """
    try:
        # stream=True leaves the body unread so it can be copied from
        # ``res.raw``; without it requests reads the whole body into
        # ``res.content`` and ``res.raw`` yields no bytes (empty file).
        with requests.get(url, stream=True, timeout=timeout) as res:
            res.raise_for_status()
            # Have urllib3 undo any gzip/deflate Content-Encoding on raw reads.
            res.raw.decode_content = True
            with open(path, mode='wb') as f:
                shutil.copyfileobj(res.raw, f)
    except Exception as e:
        raise RuntimeError(e) from e
def create_net(param_file=None, device="cpu"):
    """Build a torchvision ResNet-34 and move it to ``device``.

    Parameters
    ----------
    param_file : str or None
        Optional path to a saved ``state_dict``. When given, it is loaded with
        ``load_state_dict`` (strict by default, so every parameter and buffer
        is replaced). The ImageNet-pretrained weights are therefore only
        downloaded when no ``param_file`` is supplied.
    device : str
        Torch device used as ``map_location`` for loading and as the target
        of ``net.to``.

    Returns
    -------
    torch.nn.Module
        The network on ``device`` (still in training mode; callers call
        ``eval()`` themselves).
    """
    # NOTE(review): ``pretrained=`` is deprecated in torchvision >= 0.13 in
    # favor of ``weights=``; kept here for compatibility with older versions.
    net = models.resnet34(pretrained=not param_file)
    if param_file:
        # torch.load unpickles the file -- only load trusted weight files.
        state_dict = torch.load(param_file, map_location=torch.device(device))
        net.load_state_dict(state_dict)
    net.to(device)
    return net
def load_label(label_path):
    """Load the class-index-to-label mapping from a JSON file.

    Parameters
    ----------
    label_path : str or path-like
        JSON file to read. ``Predictor`` looks entries up by the string form
        of a class index and uses element ``[1]`` of each entry as the label.

    Returns
    -------
    The decoded JSON object.
    """
    # JSON text is UTF-8; don't depend on the platform's default encoding.
    with open(label_path, encoding='utf-8') as f:
        return json.load(f)
class Predictor():
    """Classify PIL images with a ResNet-34 and map the result to a label.

    Parameters
    ----------
    label_path : str or path-like
        JSON file mapping class-index strings ("0", "1", ...) to entries
        whose element ``[1]`` is the label returned by ``predict``.
    param_file : str or None
        Optional state-dict file forwarded to ``create_net``.
    device : str
        Torch device the network runs on; inputs are moved there as well.
    """

    # Per-channel mean/std used by torchvision's ImageNet-pretrained models;
    # inputs must be normalized the same way for the weights to apply.
    _MEAN = [0.485, 0.456, 0.406]
    _STD = [0.229, 0.224, 0.225]

    def __init__(self, label_path, param_file=None, device="cpu"):
        self.device = torch.device(device)
        # Honor the caller's device instead of always using the CPU.
        self.net = create_net(param_file, device=device)
        self.net.eval()
        self.trans = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=self._MEAN, std=self._STD),
        ])
        self.classes = load_label(label_path)

    @torch.no_grad()
    def predict(self, img_pil):
        """Return ``(class_index, label)`` for the most likely class.

        Parameters
        ----------
        img_pil : PIL.Image.Image
            Input image in any mode; it is converted to RGB first.

        Returns
        -------
        tuple[int, object]
            The argmax class index and ``self.classes[str(index)][1]``.
        """
        # ToTensor keeps the source channel count, so RGBA / grayscale /
        # palette images would not match the network's 3-channel input.
        img_rgb = img_pil.convert("RGB")
        # Add the batch dimension and place the input next to the weights.
        img_tensor = self.trans(img_rgb).unsqueeze(0).to(self.device)
        logits = self.net(img_tensor)
        # Softmax is monotonic, so the argmax of the logits is the top class.
        y_idx = int(torch.argmax(logits, dim=1)[0])
        return y_idx, self.classes[str(y_idx)][1]
Dockerfile
FROM python:3.9
# key=value form; the legacy space-separated form triggers a BuildKit warning.
ENV PYTHONUNBUFFERED=1
WORKDIR /usr/src/app
# Copy requirements alone first so the dependency layer stays cached when
# only the application source changes.
COPY requirements.txt /usr/src/app/
# --no-cache-dir keeps pip's download cache out of the image layer.
RUN pip install --no-cache-dir -r requirements.txt
COPY . /usr/src/app/
# --preload imports main:app (building the model) once before forking the 2 workers.
CMD ["gunicorn", "main:app", "-b", "0.0.0.0:5000", "--access-logfile", "-", "-w", "2", "--preload"]
docker-compose.yml
# No top-level `version` key: it is obsolete in the Compose Specification
# and `docker compose` only warns about it.
services:
  worker:
    build:
      context: ./
    container_name: flask_app
    ports:
      - "5000:5000"
    working_dir: /usr/src/app/
クライアントのコード
client.py
import base64
import os
import requests
# Bypass any configured HTTP proxy so the request reaches localhost directly.
os.environ['no_proxy'] = '*'
# Path of the image to Base64-encode
target_file = "sample.jpg"
with open(target_file, 'rb') as f:
    data = f.read()
# Base64-encode the image; JSON needs text, so turn the bytes into str
encoded = base64.b64encode(data).decode('utf-8')
url = 'http://localhost:5000'
payload = {'image': encoded}
# Without a timeout, requests can block forever if the server never answers.
res = requests.post(url, json=payload, timeout=60)
print(res.text)
# After showing the server's message, exit non-zero on an HTTP error status.
res.raise_for_status()
実行
docker compose up でサーバーのコンテナを起動します。
python client.py を実行すると画像がサーバーに送信され、分類結果が返ってきます。