はじめに
この記事は、5G Edge Computing Challenge with AWS and Verison というハッカソンで作った災害判定API の作成方法を紹介します。
特に、事前学習済みモデルからAmazon SageMaker を利用して機械学習API を作成する方法について書きます。
もっと簡単に作れるだろという暖かいツッコミもお待ちしております。
災害判定API
今回作った災害判定API は、Evacuation Support App というシステムのAPI の一つです。
Evacuation Support App は超簡単にいうと、ドローンで撮影した写真から、その地点で災害が起こっているかを判定し、安全な避難経路を案内するシステムです。
災害判定API は、この「ドローンで撮影した写真から、その地点で災害が起こっているかを判定する」API です。
Amazon SageMaker
Amazon SageMaker は機械学習API を作成するためのプラットフォームサービスです。
モデルの学習から、機械学習API のデプロイまで、フルライフサイクルで活用可能なサービスです。
AWS には他にも画像認識サービスとしてAmazon Rekognition がありますが、今回は学習済みモデルを利用したAPI を作成したかったので、SageMaker を採用しました。
Amazon SageMaker 推論API 概要
SageMaker の推論API は、次の図のように4つの関数から構成されます。
input_fn
と output_fn
がAPI 化するためのインターフェースを担う部分、model_fn
と predict_fn
が推論を行うためのロジックを担当します。
input_fn
input_fn
は、クライアントからのRequest を機械学習モデルへのインプットに変換します。
実装によっては、クライアントからのリクエストとして application/json
はもちろんのこと、image/jpeg
など様々なContent Type をサポートすることができます。
今回は、画像バイナリをリクエストとして受け取り、numpy のndarray に変換します。
model_fn
model_fn
は、API 内で使用するモデルを提供します。
predict_fn
predict_fn
は、 input_fn
によって生成されたインプットと model_fn
によって提供された機械学習モデルを用いて、実際に推論します。
output_fn
output_fn
は、 predict_fn
による推論結果をクライアントへのResponse に変換します。
使用したモデル
モデルには、MIT により公開されているIncidentsDataset の学習済みモデルを利用しています。
IncidentsDataset は、SNS に投稿された災害画像から構成された、災害検出のためのデータセットです。
実装手順
ここからは、Amazon SageMaker 上で学習済みモデルから機械学習API を実装する手順を説明します。
-
input_fn
関数を実装 -
model_fn
関数を実装 -
predict_fn
関数を実装 -
output_fn
関数を実装 - SageMaker にデプロイ
構成
- Python version: 3.6
- ML Framework: PyTorch
- ML Framework version: 1.8.1
1. model_fn
関数を実装
model_fn
は単純にモデルを生成して、返却するのみです。
def model_fn(model_dir):
"""Load models"""
logger.info('START model_fn')
# Assemble arguments
parser = get_parser()
args = parser.parse_args(
args=f"--config={model_dir}/{CONFIG_FILENAME} " +
f"--checkpoint_path={model_dir}/{CHECKPOINT_PATH_FOLDER} " +
"--mode=test " +
"--num_gpus=0")
# Create models
incident_model = get_incidents_model(args, model_dir)
# Load pretrained weights
update_incidents_model_with_checkpoint(incident_model, args)
# Change mode into eval
update_incidents_model_to_eval_mode(incident_model)
logger.info('END model_fn')
return incident_model
モデルの生成に使用するファイルなど (checkpoint など) は、tar 形式で固めてS3 などに置いておき、デプロイ時に場所を明示します。
これらのファイルを model_fn
では、tar が解凍された状態で model_dir
として利用できます。
その他モデルを生成するために使用している関数は、IncidentsDataset で用意されている関数を流用しています。
model_fn
は、とにかく機械学習モデルが返却できればなんでもよいので は、参考までに。
2. input_fn
関数を実装
input_fn
は任意のContent Type のリクエストを受け、機械学習モデルのインプットに整形して返却します。
def input_fn(request_body, content_type=JSON_CONTENT_TYPE):
"""Convert request body into input of models"""
logger.info("START input_fn")
if content_type == JPEG_CONTENT_TYPE or content_type == PNG_CONTENT_TYPE or content_type == JSON_CONTENT_TYPE:
if content_type == JPEG_CONTENT_TYPE or content_type == PNG_CONTENT_TYPE:
f = BytesIO(request_body)
else:
# Get image from S3
s3 = boto3.resource("s3")
bucket = s3.Bucket(request_body["s3_bucket_name"])
obj = bucket.Object(request_body["s3_object_name"])
response = obj.get()
f = BytesIO(response["Body"].read())
try:
input_data = Image.open(f).convert('RGB')
input_data = inference_loader(input_data)
except:
input_data = Image.new('RGB', (300, 300), 'white')
input_data = inference_loader(input_data)
input_data = torch.unsqueeze(input_data, 0)
else:
logger.error(f'Content-Type invalid: {content_type}')
input_data = {'errors': [f'Content-Type invalid: {content_type}']}
logger.info("END input_fn")
return input_data
上の例では、 image/jpeg
、 image/png
、 application/json
をサポートします。
image/jpeg
、 image/png
は、バイナリを直接ndarray に変換します。
application/json
は、S3 からリクエストで指定された画像を取得し、ndarray に変換します。
request_body の型がバイト配列であることに注意
3. predict_fn
関数を実装
predict_fn
は、インプット、機械学習モデルを用いて推論します。
def predict_fn(input_data, model):
"""Predict from input"""
logger.info("START predict_fn")
if isinstance(input_data, dict) and "errors" in input_data:
return input_data
trunk_model, incident_layer, _ = model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
input_on_device = input_data.to(device)
output = trunk_model(input_on_device)
incident_output = incident_layer(output)
incident_output = F.softmax(incident_output, dim=1)
logger.info("END predict_fn")
return incident_output
特に説明することもないのですが、引数の input_data
は input_fn
によって生成されたインプットです。
model
は model_fn
によって提供される機械学習モデルです。
predict_fn
でどこまでやるかは、実装者の裁量によりますが、ここでは機械学習モデルによる推論のみ行っています。
4. output_fn
関数を実装
output_fn
は、推論結果をAPI のレスポンスに整形して返却します。
def output_fn(prediction, accept=JSON_CONTENT_TYPE):
"""Convert output of models into response body"""
logger.info("START output_fn")
logger.info(f"Accept: {accept}")
if isinstance(prediction, dict) and "errors" in prediction:
logger.info("SKIP output_fn")
response = json.dumps(prediction)
content_type = JSON_CONTENT_TYPE
else:
incident_map = get_index_to_incident_mapping()
topk = 3
incident_threshold = 0.5
incident_probs, incident_idx = prediction.sort(1, True)
incidents = []
probs = incident_probs[0].cpu().detach().numpy()[:topk].tolist()
if probs[0] < incident_threshold:
incidents.append("no incident")
else:
for idx in incident_idx[0].cpu().numpy()[:topk]:
if idx < len(incident_map):
incidents.append(incident_map[idx])
else:
incidents.append("no incident")
response = json.dumps({"results": {"incidents": incidents, "probs": probs}})
content_type = JSON_CONTENT_TYPE
logger.info("END output_fn")
return response, content_type
ごちゃごちゃと色々やっていますが、結局は推論結果をJSON 形式に変換しているだけです。
その過程で、推論結果のベクトルから対応する災害を表す文字列に変換していたりします。
5. SageMaker でデプロイ
ここまでできれば、あとはデプロイするのみです。
面倒くさいので 諸事情により、スクリーンショットなどは省きますが、雰囲気だけでも伝えられればと思います。
-
SageMaker でノートブックインスタンスを立てる
-
ノートブックインスタンスに、上記のソースを含んだGit リポジトリをクローンする
-
新規のJupyter Notebook を作成する
-
SageMaker セッションを開始する
sagemaker_session = sagemaker.Session() role = sagemaker.get_execution_role()
-
PyTorch モデルを作成する
from sagemaker.pytorch.model import PyTorchModel # Make model pytorch_model = PyTorchModel(model_data="s3://hoge/model.tar.gz", # モデル生成に必要なファイルを含むtar role=role, framework_version="1.8.1", py_version="py36", source_dir="src", # スクリプトを含むディレクトリ entry_point="entrypoint.py" # input_fn, model_fn などを含むPython スクリプト )
-
モデルをデプロイ
deploy_params = { "instance_type": "ml.t2.medium", "initial_instance_count": 1 } predictor = pytorch_model.deploy(**deploy_params)
まとめ
今回はAmazon SageMaker で、災害検知API を作ってみました。
実装すべき関数が既に決まっていて、こちらで中身を実装するだけで機械学習API が実装できるのはお手軽でですね。
次は、学習を回したり、継続的学習のためのパイプラインを作ってみたりしたいですね。