17
15

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

AWS SageMakerでモデルのデプロイから推論(バッチ変換ジョブ)までを行う

Posted at

本記事でやること

  • 前回の記事で行なったトレーニングジョブ結果から生成されたモデルのデプロイを行う。
  • S3に置いてあるデータをまとめて推論するバッチ変換ジョブを実行する。

本記事でやらないことは以下の通りなので、他の記事を参照してください。

  • 前回の記事で行なった独自アルゴリズムを使ったトレーニングジョブの実行方法
  • SageMakerやS3などにアクセスするためのIAMロールの作成

対象読者

  • とりあえずAWS SageMakerを動かしてみたい人
  • Dockerに関して一通りの基礎知識がある人

使用言語

  • Python 3.6.3

トレーニングジョブ結果から生成されたモデルのデプロイを行う

公式ドキュメントに記載がある通り、推論を行う際にアプリケーションがリクエストを送るエンドポイントは、モデルをデプロイした際に取得されます。なので、トレーニングジョブが完了された後にモデルのデプロイを必ず行う必要があります。

以下、コードからモデルのデプロイを行います。


class SagemakerClient:

    def __init__(self):
        self.client = Session(profile_name="hoge").client("sagemaker", region_name="ap-northeast-1")

    def create_model(self, model_data_url):

        model_params = {
            "ExecutionRoleArn": "arn:aws:iam::123:role/dev-sagemaker", 
            "ModelName": "sample-model", # モデル名
            "PrimaryContainer": {
                "Image": "123.dkr.ecr.ap-northeast-1.amazonaws.com/sagemaker-repo:latest", # ECRにプッシュしたイメージURL
                "ModelDataUrl": model_data_url # モデルデータが格納されているS3のパス
            }
        }

        self.client.create_model(**model_params)


if __name__ == '__main__':
    model_data_url = ...
    SagemakerClient().create_model(model_data_url)

上記コードでのcreate_modelメソッドの引数(model_data_url)は、boto3のdescribe_training_jobから得られるのでこちらの公式ドキュメントをご参照ください。

参考程度にmodel_data_urlを取得するコードを以下に記載いたします。また、describe_training_jobの引数TrainingJobNameはトレーニングジョブを送信する際に返ってくるTransformJobArnから取得をしています。

client = Session(profile_name="hoge").client("sagemaker", region_name="ap-northeast-1")

model_data_url = client.describe_training_job(TrainingJobName=training_job_name)['ModelArtifacts']['S3ModelArtifacts']

以下のようにAmazon SageMaker > モデルにモデルが現れていれば無事完了です。

スクリーンショット 2019-01-27 16.50.05.png

S3に置いてあるデータをまとめて推論するバッチ変換ジョブを実行

今回、推論時に使用したデータや推論結果の出力先などは以下のようなフォルダ構成でS3に置きました。

bucket
├── input-data-prediction
│     └── YYYY-MM-DD
│          └── multiclass
│               └── iris.csv
├── output-data-prediction
│     └── YYYY-MM-DD
│          └── multiclass
│               
├── input-data-training
│     └── YYYY-MM-DD
│          └── multiclass
│               └── iris.csv
└── output-model
      └── YYYY-MM-DD
           └── multiclass
                └── model名(←ここからはSageMakerが出力する)

以下のコードからバッチ変換ジョブを送信します。
ジョブを送信する際の引数であるModelName は、デプロイしたモデル名を使用します。
デプロイしたモデル名は、boto3のlist_modelsメソッドで得られるのでこちらの公式ドキュメントをご参照ください。(モデル名に含まれている文字列とデプロイした時間からしか検索ができないみたいです)


class SagemakerClient:

    def __init__(self):
        self.client = Session(profile_name="hoge").client("sagemaker", region_name="ap-northeast-1")

    def submit_transform_job(self):
        
        model_name = self.client.list_models(
            NameContains="base_model_name", # 各自デプロイしたモデル名に含まれている文字列
            SortOrder='Descending',
            SortBy='CreationTime')["Models"][0]["ModelName"]
        
        transform_params = {
            "TransformJobName": "sample-transform_job_name", # バッチ変換ジョブ名
            "ModelName": model_name, # デプロイしたモデル名
            "MaxConcurrentTransforms": 2,
            "MaxPayloadInMB": 50,
            "BatchStrategy": "MultiRecord",
            "TransformOutput": {
                "S3OutputPath": "s3://bucket/output-data-prediction/YYYY-MM-DD/multiclass/" # 推論結果を格納するS3パス
            },
            "TransformInput": {
                "DataSource": {
                    "S3DataSource": {
                        "S3DataType": "S3Prefix",
                        "S3Uri": "s3://bucket/input-data-prediction/YYYY-MM-DD/multiclass/" # 推論を行うインプットデータが格納されているS3パス
                    }
                },
                "ContentType": "text/csv",
                "SplitType": "Line"
            },
            "TransformResources": {
                "InstanceType": "ml.c4.xlarge",
                "InstanceCount": 1
            }
        }

        self.client.create_transform_job(**transform_params)

if __name__ == '__main__':
    SagemakerClient().submit_transform_job()

以下のようにAmazon SageMaker > バッチ変換ジョブでステータスがCompletedになれば無事完了です。上記で指定したS3のパスに推論結果が格納されているはずです。

スクリーンショット 2019-01-27 19.02.53.png

終わりに

今回は、モデルをデプロイ~バッチ変換ジョブを送信するコードのみを記載しましたが、SageMakerのDockerコンテナ内にあるpredictor.pyserve公式のgithubにあるサンプルをそのまま使用しています。

17
15
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
17
15

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?