6
3

Amazon Bedrock のバッチ推論が GA したので試してみた

Last updated at Posted at 2024-08-22

Amazon Bedrock のバッチ推論が GA したので早速試してみました。バッチ推論だと 50% の料金なのは嬉しいです。

利用できるモデルの一覧はこちら。

クォータはこちら。

バッチ推論

以下の検証では Amazon Bedrock への権限が付与された SageMaker Notebook で行っております。

最新の Boto3 を利用します。

!pip install -qU boto3

必要なクライアントを作成します。

import boto3
import sagemaker

s3 = boto3.client('s3')
bedrock = boto3.client(service_name="bedrock")
bedrock_runtime = boto3.client("bedrock-runtime")

# default bucket
sagemaker_session = sagemaker.Session()
bucket = sagemaker_session.default_bucket()

推論は boto3 の bedrock_runtime の invoke_model のパラメータを modelInput として設定します。recordId は設定しない場合自動で補完してくれるそうです。

気をつけるべきクォータとしてはバッチ推論に必要な最小レコード数が 1000 件なので検証ではそれ以上のレコードを入力する必要があります。

import json
import base64
import httpx

# Open image
image1_url = "https://upload.wikimedia.org/wikipedia/commons/a/a7/Camponotus_flavomarginatus_ant.jpg"
image1_media_type = "image/jpeg"
image1_data = base64.b64encode(httpx.get(image1_url).content).decode("utf-8")

contents = [
    {
        # "recordId": "1",
        "modelInput": {
            "anthropic_version": "bedrock-2023-05-31",
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "あなたは誰ですか?"}
                    ]
                },
            ],
            "temperature": 0.3,
            "max_tokens": 500,
        }
    },
    {
        # "recordId": "2",
        "modelInput": {
            "anthropic_version": "bedrock-2023-05-31",
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image",
                            "source": {
                                "type": "base64",
                                "media_type": image1_media_type,
                                "data": image1_data
                            }
                        },
                        {"type": "text", "text": "何が写っていますか?"}
                    ]
                },
            ],
            "temperature": 0.3,
            "max_tokens": 500,
        }
    }
]

# レコード数の水増し
contents = contents + [contents[0]] * 1000

念のためパラメータが間違ってないか通常の invoke_model で動作確認します。

# Check if normal request success

for content in contents[:5]:
    response = bedrock_runtime.invoke_model(
        modelId="anthropic.claude-3-haiku-20240307-v1:0",
        body=json.dumps(content["modelInput"]),
    )
    model_response = json.loads(response["body"].read())
    response_text = model_response["content"][0]["text"]
    print(response_text)

問題なければ入力を JSONL 形式に変換し S3 にアップロードしバッチ推論を実行します。

import time

input_key = "batch/input/abc.jsonl"
output_key = "batch/output/"

jsonl_content = "\n".join([json.dumps(record) for record in contents])
s3.put_object(Body=jsonl_content, Bucket=bucket, Key=input_key)

inputDataConfig = ({
    "s3InputDataConfig": {
        "s3Uri": f"s3://{bucket}/{input_key}"
    }
})

outputDataConfig=({
    "s3OutputDataConfig": {
        "s3Uri": f"s3://{bucket}/{output_key}"
    }
})

# date
jobName = "my-batch-job-" + time.strftime("%Y%m%d%H%M%S")
response = bedrock.create_model_invocation_job(
    roleArn=sagemaker.get_execution_role(),
    modelId="anthropic.claude-3-haiku-20240307-v1:0",
    jobName=jobName,
    inputDataConfig=inputDataConfig,
    outputDataConfig=outputDataConfig
)

jobArn = response.get('jobArn')

Job が完了するまでしばらく待ちます。

# check status every 30 sec and stop when status if "Completed" or "Failed"
while True:
    time.sleep(30)
    status = bedrock.get_model_invocation_job(jobIdentifier=jobArn)['status']
    print(status)
    if status in ["Completed", "Failed"]:
        break

Job の詳細を確認します。失敗した際のエラーメッセージなども確認できます。

jobDetail = bedrock.get_model_invocation_job(jobIdentifier=jobArn)
jobDetail

Job の結果を取得します。

# Get Output Path
outputPath = f'{output_key}{jobArn.split("/").pop()}/{input_key.split("/").pop()}.out'
print(outputPath)

# Retrieve output
obj = s3.get_object(Bucket=bucket, Key=outputPath)
obj_decoded = obj["Body"].read().decode("utf-8")
results = [json.loads(row) for row in obj_decoded.split('\n') if row]
outputs = [row["modelOutput"]["content"] for row in results]
for output in outputs[:5]:
    print(output)

正しく処理されていることが確認できます。画像入力にも対応できているようです。

# output
[{'type': 'text', 'text': '私は人工知能のアシスタントです。人間の皆さまを助けるために作られました。会話を通して、様々な質問に答えたり、課題を一緒に解決したりすることができます。どのようなお手伝いができるでしょうか?'}]
[{'type': 'text', 'text': 'この画像には黒い色をした小さな昆虫が写っています。その昆虫は地面に座っており、細長い脚と触角を持っています。背景は暗めの色調で、昆虫の細部がよく見えるように撮影されています。この昆虫は恐らくアリやハチなどの小型の昆虫の一種だと思われます。'}]
[{'type': 'text', 'text': '私は人工知能のAssistantです。人間のようには見えませんが、人間とコミュニケーションを取ることができます。私の役割は、人間の皆さまの様々な質問に答えたり、課題を一緒に解決したりすることです。どのようなお手伝いができるでしょうか?'}]
[{'type': 'text', 'text': '私は人工知能のAssistantです。人間のようには話せませんが、できる限り皆さんのお役に立てるよう頑張っています。何か質問やお手伝いできることはありますか?'}]
[{'type': 'text', 'text': '私は人工知能のアシスタントです。人間のようには見えませんが、人間とコミュニケーションを取ることができます。皆さまのお役に立てるよう、できる限りの対応をさせていただきます。どのようなお話しができればよろしいでしょうか?'}]

まとめ

LLM のバッチ推論はベンチマークやデータ作成、評価など活用できるところが多いので、バッチ推論でコストが半分になるのは嬉しいですね。いろいろ試していきたいです。

6
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
3