0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Amazon SageMaker で事前学習済みモデルから災害判定APIを作る

Last updated at Posted at 2021-07-17

はじめに

この記事は、5G Edge Computing Challenge with AWS and Verison というハッカソンで作った災害判定API の作成方法を紹介します。
特に、事前学習済みモデルからAmazon SageMaker を利用して機械学習API を作成する方法について書きます。
もっと簡単に作れるだろ:rage:という暖かいツッコミもお待ちしております。

災害判定API:fire_engine:

今回作った災害判定API は、Evacuation Support App というシステムのAPI の一つです。
Evacuation Support App は超簡単にいうと、ドローンで撮影した写真から、その地点で災害が起こっているかを判定し、安全な避難経路を案内するシステムです。
災害判定API は、この「ドローンで撮影した写真から、その地点で災害が起こっているかを判定する」API です。

Amazon SageMaker

Amazon SageMaker は機械学習API を作成するためのプラットフォームサービスです。
モデルの学習から、機械学習API のデプロイまで、フルライフサイクルで活用可能なサービスです。
AWS には他にも画像認識サービスとしてAmazon Rekognition がありますが、今回は学習済みモデルを利用したAPI を作成したかったので、SageMaker を採用しました。

Amazon SageMaker 推論API 概要

SageMaker の推論API は、次の図のように4つの関数から構成されます。

function-flow.drawio.png

input_fnoutput_fn がAPI 化するためのインターフェースを担う部分、model_fnpredict_fn が推論を行うためのロジックを担当します。

input_fn

input_fn は、クライアントからのRequest を機械学習モデルへのインプットに変換します。
実装によっては、クライアントからのリクエストとして application/json はもちろんのこと、image/jpeg など様々なContent Type をサポートすることができます。
今回は、画像バイナリをリクエストとして受け取り、numpy のndarray に変換します。

model_fn

model_fn は、API 内で使用するモデルを提供します。

predict_fn

predict_fn は、 input_fn によって生成されたインプットと model_fn によって提供された機械学習モデルを用いて、実際に推論します。

output_fn

output_fn は、 predict_fn による推論結果をクライアントへのResponse に変換します。

使用したモデル

モデルには、MIT により公開されているIncidentsDataset の学習済みモデルを利用しています。
IncidentsDataset は、SNS に投稿された災害画像から構成された、災害検出のためのデータセットです。

実装手順

ここからは、Amazon SageMaker 上で学習済みモデルから機械学習API を実装する手順を説明します。

  1. input_fn 関数を実装
  2. model_fn 関数を実装
  3. predict_fn 関数を実装
  4. output_fn 関数を実装
  5. SageMaker にデプロイ

構成

  • Python version: 3.6
  • ML Framework: PyTorch
  • ML Framework version: 1.8.1

1. model_fn 関数を実装

model_fn は単純にモデルを生成して、返却するのみです。

    def model_fn(model_dir):
        """Load models"""
        logger.info('START model_fn')
    
        # Assemble arguments
        parser = get_parser()
        args = parser.parse_args(
            args=f"--config={model_dir}/{CONFIG_FILENAME} " +
            f"--checkpoint_path={model_dir}/{CHECKPOINT_PATH_FOLDER} " +
            "--mode=test " +
            "--num_gpus=0")
    
        # Create models
        incident_model = get_incidents_model(args, model_dir)
        # Load pretrained weights
        update_incidents_model_with_checkpoint(incident_model, args)
        # Change mode into eval
        update_incidents_model_to_eval_mode(incident_model)
    
        logger.info('END model_fn')
        return incident_model

モデルの生成に使用するファイルなど (checkpoint など) は、tar 形式で固めてS3 などに置いておき、デプロイ時に場所を明示します。
これらのファイルを model_fn では、tar が解凍された状態で model_dir として利用できます。
その他モデルを生成するために使用している関数は、IncidentsDataset で用意されている関数を流用しています。
model_fn は、とにかく機械学習モデルが返却できればなんでもよいので :point_up: は、参考までに。

2. input_fn 関数を実装

input_fn は任意のContent Type のリクエストを受け、機械学習モデルのインプットに整形して返却します。

    def input_fn(request_body, content_type=JSON_CONTENT_TYPE):
        """Convert request body into input of models"""
        logger.info("START input_fn")
        if content_type == JPEG_CONTENT_TYPE or content_type == PNG_CONTENT_TYPE or content_type == JSON_CONTENT_TYPE:
            if content_type == JPEG_CONTENT_TYPE or content_type == PNG_CONTENT_TYPE:
                f = BytesIO(request_body)
            else:
                # Get image from S3
                s3 = boto3.resource("s3")
                bucket = s3.Bucket(request_body["s3_bucket_name"])
                obj = bucket.Object(request_body["s3_object_name"])
                response = obj.get()
                f = BytesIO(response["Body"].read())
            try:
                input_data = Image.open(f).convert('RGB')
                input_data = inference_loader(input_data)
            except:
                input_data = Image.new('RGB', (300, 300), 'white')
                input_data = inference_loader(input_data)
            input_data = torch.unsqueeze(input_data, 0)
        else:
            logger.error(f'Content-Type invalid: {content_type}')
            input_data = {'errors': [f'Content-Type invalid: {content_type}']}
    
        logger.info("END input_fn")
        return input_data

上の例では、 image/jpegimage/pngapplication/json をサポートします。
image/jpegimage/png は、バイナリを直接ndarray に変換します。
application/json は、S3 からリクエストで指定された画像を取得し、ndarray に変換します。

request_body の型がバイト配列であることに注意

3. predict_fn 関数を実装

predict_fn は、インプット、機械学習モデルを用いて推論します。

    def predict_fn(input_data, model):
        """Predict from input"""
        logger.info("START predict_fn")
        if isinstance(input_data, dict) and "errors" in input_data:
            return input_data
    
        trunk_model, incident_layer, _ = model
    
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        input_on_device = input_data.to(device)
        output = trunk_model(input_on_device)
        incident_output = incident_layer(output)
        incident_output = F.softmax(incident_output, dim=1)
    
        logger.info("END predict_fn")
    
        return incident_output

特に説明することもないのですが、引数の input_datainput_fn によって生成されたインプットです。
modelmodel_fn によって提供される機械学習モデルです。
predict_fn でどこまでやるかは、実装者の裁量によりますが、ここでは機械学習モデルによる推論のみ行っています。

4. output_fn 関数を実装

output_fn は、推論結果をAPI のレスポンスに整形して返却します。

    def output_fn(prediction, accept=JSON_CONTENT_TYPE):
        """Convert output of models into response body"""
        logger.info("START output_fn")
        logger.info(f"Accept: {accept}")
    
        if isinstance(prediction, dict) and "errors" in prediction:
            logger.info("SKIP output_fn")
            response = json.dumps(prediction)
            content_type = JSON_CONTENT_TYPE
        else:
            incident_map = get_index_to_incident_mapping()
            topk = 3
            incident_threshold = 0.5
    
            incident_probs, incident_idx = prediction.sort(1, True)
            incidents = []
            probs = incident_probs[0].cpu().detach().numpy()[:topk].tolist()
            if probs[0] < incident_threshold:
                incidents.append("no incident")
            else:
                for idx in incident_idx[0].cpu().numpy()[:topk]:
                    if idx < len(incident_map):
                        incidents.append(incident_map[idx])
                    else:
                        incidents.append("no incident")
    
            response = json.dumps({"results": {"incidents": incidents, "probs": probs}})
            content_type = JSON_CONTENT_TYPE
    
        logger.info("END output_fn")
    
        return response, content_type

ごちゃごちゃと色々やっていますが、結局は推論結果をJSON 形式に変換しているだけです。
その過程で、推論結果のベクトルから対応する災害を表す文字列に変換していたりします。

5. SageMaker でデプロイ

ここまでできれば、あとはデプロイするのみです。
面倒くさいので 諸事情により、スクリーンショットなどは省きますが、雰囲気だけでも伝えられればと思います。

  1. SageMaker でノートブックインスタンスを立てる

  2. ノートブックインスタンスに、上記のソースを含んだGit リポジトリをクローンする

  3. 新規のJupyter Notebook を作成する

  4. SageMaker セッションを開始する

    sagemaker_session = sagemaker.Session()
    role = sagemaker.get_execution_role()
    
  5. PyTorch モデルを作成する

    from sagemaker.pytorch.model import PyTorchModel
    
    # Make model
    pytorch_model = PyTorchModel(model_data="s3://hoge/model.tar.gz",  # モデル生成に必要なファイルを含むtar
                                 role=role,
                                 framework_version="1.8.1",
                                 py_version="py36",
                                 source_dir="src",  # スクリプトを含むディレクトリ
                                 entry_point="entrypoint.py"  # input_fn, model_fn などを含むPython スクリプト
                                 )
    
  6. モデルをデプロイ

    deploy_params = {
        "instance_type": "ml.t2.medium",
        "initial_instance_count": 1
    }
    predictor = pytorch_model.deploy(**deploy_params)
    

まとめ

今回はAmazon SageMaker で、災害検知API を作ってみました。
実装すべき関数が既に決まっていて、こちらで中身を実装するだけで機械学習API が実装できるのはお手軽で:thumbsup:ですね。
次は、学習を回したり、継続的学習のためのパイプラインを作ってみたりしたいですね。

0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?