0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

SageMakerのリアルタイム推論エンドポイント(Pytorch)のリクエスト時にカスタム属性を渡す

Last updated at Posted at 2023-12-27

はじめに

SageMakerのリアルタイム推論エンドポイント(フレームワーク:Pytorch)に対して推論リクエストを行う際に、Content-Typeヘッダーやリクエストボディのような必須の情報とは別に、追加の情報(以下カスタム属性と呼びます)を付与してリクエストしたいという場面がありました。

SageMakerのリアルタイム推論エンドポイント(フレームワーク:Pytorch)にカスタム属性を渡す方法について、AWSの公式ドキュメントには明確な記載がないのですが、色々試行錯誤しているうちに渡すことができたので、まとめます。

結論

クライアント側では、以下のようにリクエストします。

import boto3

sagemaker_runtime_client = boto3.client("sagemaker-runtime")

response = sagemaker_runtime_client.invoke_endpoint(
    EndpointName="<リアルタイム推論エンドポイントの名前>",
    ContentType="<リクエストボディのMIMEタイプ>",
    Body=<リクエストボディに格納するデータ>,
    Accept="<レスポンスボディのMIMEタイプ>",
    CustomAttributes=<カスタム属性> # カスタム属性として何かしらのデータを渡す
)

リアルタイム推論エンドポイント側では、推論コードの関数(model_fn, input_fn, predict_fn, output_fnのどれでもOK)内で、以下のようにカスタム属性を取得します。

# input_fnの例
def input_fn(request_body, request_content_type, context):
    # 推論リクエストの全ヘッダー情報を辞書型で取得する
    request_header_dir = context.get_all_request_header(0)
    
    # カスタム属性を取得する
    custom_attributes = request_header_dir["X-Amzn-SageMaker-Custom-Attributes"]

前提

SageMakerのリアルタイム推論エンドポイント(フレームワーク:Pytorch)は、
AWS Deep Learning Containersで提供されているコンテナイメージを使用して立てています。

具体的には、Boto3で以下のようなコードを書いてデプロイしています。

import boto3
from sagemaker import image_uris

aws_region = "ap-northeast-1"

# AWS Deep Learning Containersで提供されているコンテナイメージのURIを取得する
container = image_uris.retrieve(
    region=aws_region, 
    framework="pytorch", 
    version="1.12.1",
    image_scope="inference",
    instance_type="ml.g4dn.xlarge",
    py_version="py38"
)

sagemaker_client = boto3.client("sagemaker", region_name=aws_region)

# コンテナイメージのURIを指定してSageMakerモデルを作成する
sagemaker_client.create_model(
    ModelName = "<作成するSageMakerモデルの名前(任意)>",
    ExecutionRoleArn = "<SageMakerの実行ロールのARN>",
    PrimaryContainer = {
        "Image": container,
        "ModelDataUrl": "<S3にアップロードしたモデルアーティファクトのURL>"
    }
)

# 以下、エンドポイント設定を作成し、エンドポイントを立てる(省略)

公式ドキュメントに書かれていること

クライアント側について

Amazon SageMaker の InvokeEndpoint API

Amazon SageMakerのInvokeEndpoint API にはX-Amzn-SageMaker-Custom-Attributesヘッダーを設定することができ、推論のリクエストに関する追加情報を渡すことができます。

Boto3 の invoke_endpoint メソッド

Boto3 の invoke_endpoint メソッドには、上記 InvokeEndpoint API のX-Amzn-SageMaker-Custom-Attributesヘッダーに対応するCustomAttributesという引数が用意されています。

リアルタイム推論エンドポイント側について

Amazon SageMaker Python SDK の Pytorch モデルのデプロイに関するドキュメントには、X-Amzn-SageMaker-Custom-Attributesヘッダーに設定された情報を取得するための方法等は記載されていません。

一方で、推論コード内に定義する関数(model_fn, input_fn, predict_fn, output_fn)でcontextというオプションの引数を受け取ることができると記載されています。

contextの詳細については、「serve(PytorchでモデルサービングするためのPythonパッケージ)のContextクラスのコードを読め」と記載されています。

serveパッケージのContextクラスのコードをよく読んでみると、get_all_request_header()というそれっぽい名前のメソッドがありました。

get_all_request_header()int型の引数を1つ受け取るようなので、とりあえず0を渡してみると、期待するもの(推論リクエストの全ヘッダーの情報)がたまたま取得できた、というオチです。

余談

Amazon SageMaker Python SDK の TensorFlow モデルのデプロイに関するドキュメントには、X-Amzn-SageMaker-Custom-Attributesヘッダーに設定された情報を推論コード内で取得する方法が明確に記載されていました。

Pytorchのドキュメントにも書いて欲しいですね...。

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?