どうもこんにちは。
今回は、以下の記事でトレーニングしたBERTモデルをRailsアプリケーションで呼び出してみました。
前提条件
上の記事で手順12まで実行できていることを前提として進めます。
SageMakerノートブックインスタンス側での設定
1. トレーニングしたモデルをデプロイ
SageMakerのノートブックインスタンス内で以下のコードを実行します。
predictor = huggingface_estimator.deploy(
initial_instance_count=1,
instance_type='ml.g4dn.xlarge',
endpoint_name='endpoint-01' # ここは任意の名前
)
ここで発行したエンドポイントは、「リアルタイム推論」をするためのエンドポイントとなります。
「サーバレス推論」をするためのエンドポイントを発行する手順は後日記事にします。
Railsの設定
今回は、バッチ処理で実装したためその手順を記載します。
1. Gemのインストール
Gemfileに以下を記述し、bundle install
を実行します。
gem 'aws-sdk-sagemakerruntime'
2. バッチ処理用スクリプトファイルを用意
lib/batch
ディレクトリに、bert_analytics.rb
という名前でファイルを作成します。
3. 環境変数を定義
config/settings/development.yml
に環境変数を定義します。
aws:
sagemaker:
access_key: 'AWSアカウントのアクセスキー'
secret_access_key: 'AWSアカウントのシークレットアクセスキー'
region: 'AWSのリージョン'
s3_bucket: 'S3バケットの名前'
4. bert_analytics.rb
にスクリプト記述
以下のようにスクリプトを記述します。
class Batch::BertAnalytics
def self.bert_analytics
# SageMakerエンドポイント接続
access_key = Settings.aws.sagemaker.access_key
secret_access_key = Settings.aws.sagemaker.secret_access_key
region_name = Settings.aws.sagemaker.region
credentials = Aws::Credentials.new(access_key, secret_access_key)
sagemaker_client = Aws::SageMakerRuntime::Client.new(region: region_name, credentials: credentials)
# エンドポイントを使用して推論
text_comment = 'ここに任意のコメントを入力してください。'
comment_text = text_comment.length > 700 ? text_comment[0...700] : text_comment
input_data = { text: comment_text }.to_json
response = sagemaker_runtime_client.invoke_endpoint({
endpoint_name: 'endpoint-01',
body: input_data
})
result = JSON.parse(response.body.read)
Rails.logger.info result['label'].sub('LABEL_', '').to_i
end
end
710文字以上のテキストを推論できなかったため、テキストの最初の700文字を推論に使用するように処理してます。
5. 処理実行
以下のコマンドをターミナルで実行すると、任意のタイミングで処理を実行できます。
rails runner Batch::BertAnalytics.bert_analytics
この後の処理
バッチ処理のスケジューリングは、以下の記事で説明しています。
まとめ
エンドポイントを使用して推論するのは比較的簡単でした。