はじめに
SageMakerのリアルタイム推論エンドポイント(フレームワーク:Pytorch)に対して推論リクエストを行う際に、Content-Typeヘッダーやリクエストボディのような必須の情報とは別に、追加の情報(以下カスタム属性と呼びます)を付与してリクエストしたいという場面がありました。
SageMakerのリアルタイム推論エンドポイント(フレームワーク:Pytorch)にカスタム属性を渡す方法について、AWSの公式ドキュメントには明確な記載がないのですが、色々試行錯誤しているうちに渡すことができたので、まとめます。
結論
クライアント側では、以下のようにリクエストします。
import boto3
sagemaker_runtime_client = boto3.client("sagemaker-runtime")
response = sagemaker_runtime_client.invoke_endpoint(
EndpointName="<リアルタイム推論エンドポイントの名前>",
ContentType="<リクエストボディのMIMEタイプ>",
Body=<リクエストボディに格納するデータ>,
Accept="<レスポンスボディのMIMEタイプ>",
CustomAttributes=<カスタム属性> # カスタム属性として何かしらのデータを渡す
)
リアルタイム推論エンドポイント側では、推論コードの関数(model_fn, input_fn, predict_fn, output_fnのどれでもOK)内で、以下のようにカスタム属性を取得します。
# input_fnの例
def input_fn(request_body, request_content_type, context):
# 推論リクエストの全ヘッダー情報を辞書型で取得する
request_header_dir = context.get_all_request_header(0)
# カスタム属性を取得する
custom_attributes = request_header_dir["X-Amzn-SageMaker-Custom-Attributes"]
前提
SageMakerのリアルタイム推論エンドポイント(フレームワーク:Pytorch)は、
AWS Deep Learning Containersで提供されているコンテナイメージを使用して立てています。
具体的には、Boto3で以下のようなコードを書いてデプロイしています。
import boto3
from sagemaker import image_uris
aws_region = "ap-northeast-1"
# AWS Deep Learning Containersで提供されているコンテナイメージのURIを取得する
container = image_uris.retrieve(
region=aws_region,
framework="pytorch",
version="1.12.1",
image_scope="inference",
instance_type="ml.g4dn.xlarge",
py_version="py38"
)
sagemaker_client = boto3.client("sagemaker", region_name=aws_region)
# コンテナイメージのURIを指定してSageMakerモデルを作成する
sagemaker_client.create_model(
ModelName = "<作成するSageMakerモデルの名前(任意)>",
ExecutionRoleArn = "<SageMakerの実行ロールのARN>",
PrimaryContainer = {
"Image": container,
"ModelDataUrl": "<S3にアップロードしたモデルアーティファクトのURL>"
}
)
# 以下、エンドポイント設定を作成し、エンドポイントを立てる(省略)
公式ドキュメントに書かれていること
クライアント側について
Amazon SageMaker の InvokeEndpoint API
Amazon SageMakerのInvokeEndpoint API にはX-Amzn-SageMaker-Custom-Attributes
ヘッダーを設定することができ、推論のリクエストに関する追加情報を渡すことができます。
Boto3 の invoke_endpoint メソッド
Boto3 の invoke_endpoint メソッドには、上記 InvokeEndpoint API のX-Amzn-SageMaker-Custom-Attributes
ヘッダーに対応するCustomAttributes
という引数が用意されています。
リアルタイム推論エンドポイント側について
Amazon SageMaker Python SDK の Pytorch モデルのデプロイに関するドキュメントには、X-Amzn-SageMaker-Custom-Attributes
ヘッダーに設定された情報を取得するための方法等は記載されていません。
一方で、推論コード内に定義する関数(model_fn, input_fn, predict_fn, output_fn)でcontext
というオプションの引数を受け取ることができると記載されています。
context
の詳細については、「serve(PytorchでモデルサービングするためのPythonパッケージ)のContext
クラスのコードを読め」と記載されています。
serveパッケージのContext
クラスのコードをよく読んでみると、get_all_request_header()
というそれっぽい名前のメソッドがありました。
get_all_request_header()
はint
型の引数を1つ受け取るようなので、とりあえず0
を渡してみると、期待するもの(推論リクエストの全ヘッダーの情報)がたまたま取得できた、というオチです。
余談
Amazon SageMaker Python SDK の TensorFlow モデルのデプロイに関するドキュメントには、X-Amzn-SageMaker-Custom-Attributes
ヘッダーに設定された情報を推論コード内で取得する方法が明確に記載されていました。
Pytorchのドキュメントにも書いて欲しいですね...。