Why not login to Qiita and try out its useful features?

We'll deliver articles that match you.

You can read useful information later.

3
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

【Rails,AWS】SageMakerのエンドポイントをRailsで使用してみる

Last updated at Posted at 2024-04-11

どうもこんにちは。

今回は、以下の記事でトレーニングしたBERTモデルをRailsアプリケーションで呼び出してみました。

前提条件

上の記事で手順12まで実行できていることを前提として進めます。

SageMakerノートブックインスタンス側での設定

1. トレーニングしたモデルをデプロイ

SageMakerのノートブックインスタンス内で以下のコードを実行します。

predictor = huggingface_estimator.deploy(
    initial_instance_count=1,
    instance_type='ml.g4dn.xlarge',
    endpoint_name='endpoint-01' # ここは任意の名前
)

ここで発行したエンドポイントは、「リアルタイム推論」をするためのエンドポイントとなります。
「サーバレス推論」をするためのエンドポイントを発行する手順は後日記事にします。

Railsの設定

今回は、バッチ処理で実装したためその手順を記載します。

1. Gemのインストール

Gemfileに以下を記述し、bundle installを実行します。

gem 'aws-sdk-sagemakerruntime'

2. バッチ処理用スクリプトファイルを用意

lib/batchディレクトリに、bert_analytics.rbという名前でファイルを作成します。

3. 環境変数を定義

config/settings/development.ymlに環境変数を定義します。

aws:
  sagemaker:
    access_key: 'AWSアカウントのアクセスキー'
    secret_access_key: 'AWSアカウントのシークレットアクセスキー'
    region: 'AWSのリージョン'
    s3_bucket: 'S3バケットの名前'

4. bert_analytics.rbにスクリプト記述

以下のようにスクリプトを記述します。

class Batch::BertAnalytics
    def self.bert_analytics
        # SageMakerエンドポイント接続
        access_key = Settings.aws.sagemaker.access_key
        secret_access_key = Settings.aws.sagemaker.secret_access_key
        region_name = Settings.aws.sagemaker.region
        credentials = Aws::Credentials.new(access_key, secret_access_key)
        sagemaker_client = Aws::SageMakerRuntime::Client.new(region: region_name, credentials: credentials)

        # エンドポイントを使用して推論
        text_comment = 'ここに任意のコメントを入力してください。'
        comment_text = text_comment.length > 700 ? text_comment[0...700] : text_comment
        input_data = { text: comment_text }.to_json
        response = sagemaker_runtime_client.invoke_endpoint({
                                                            endpoint_name: 'endpoint-01',
                                                            body: input_data
                                                          })
        
        result = JSON.parse(response.body.read)
        Rails.logger.info result['label'].sub('LABEL_', '').to_i
    end
end

710文字以上のテキストを推論できなかったため、テキストの最初の700文字を推論に使用するように処理してます。

5. 処理実行

以下のコマンドをターミナルで実行すると、任意のタイミングで処理を実行できます。

rails runner Batch::BertAnalytics.bert_analytics

この後の処理

バッチ処理のスケジューリングは、以下の記事で説明しています。

まとめ

エンドポイントを使用して推論するのは比較的簡単でした。

3
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?