本記事でやること
- 学習したモデルに対してのエンドポイントを作成する。
- AWS SageMakerのエンドポイントへアクセスするためのlambda関数を作成する。
- アプリケーションのフロントエンドからリクエストを受け取るためのAPI Gatewayを作成する。(上記で作成したlambda関数を呼びだします)
- アプリケーションのフロントエンドをFlaskで作成する。
今回使用するAWSのサービスとアーキテクチャは以下のようにになります。
githubに今回書いたコードをあげておりますので適宜参照していただければと思います。READMEにも手順を記載しています。
本記事でやらないことは以下の通りなので、他の記事を参照してください。
対象読者
- SageMakerを使った簡単な機械学習アプリケーションを作ってみたい方
- API-Gateway / Lambdaを少し触ったことのある方
使用言語
- Python 3.6.3
- (簡単なHTML)
学習したモデルに対してのエンドポイントを作成する
前回の記事を参考にしてモデルのデプロイまでは行なっていることを前提にして進めます。(以下の画面の様にモデルが作成されて入ればokです)
以下のコードでは、エンドポイントを作成する前にエンドポイントの構成も作成しております。
公式のページにも記載してありますが、エンドポイントの構成を作成する時には、使用するモデル名とホストするコンピューティングインスタンスを指定します。
また順序として、「エンドポイントの構成の作成 -> エンドポイントの作成」となります。
また、エンドポイントの作成までには時間がかかるのでwaiter
を作成してステータスを監視します。
import logging
from boto3.session import Session
class EndPointCreator:
def __init__(self):
self.sagemaker_client = Session(profile_name="your_profile").\
client("sagemaker", region_name="ap-northeast-1")
def execute(self):
model_name = self.get_model_name()
self.create_end_point_config(model_name)
self.create_end_point()
self.wait_end_point(max_iter=120)
def get_model_name(self):
model_name = self.sagemaker_client.list_models(
NameContains="your_model_name",
SortOrder='Descending',
SortBy='CreationTime'
)["Models"][0]["ModelName"]
return model_name
def create_end_point_config(self, model_name):
self.sagemaker_client.create_endpoint_config(
EndpointConfigName="your_endpoint_config_name",
ProductionVariants=[
{
'VariantName': 'hoge',
'ModelName': model_name,
'InitialInstanceCount': 1,
'InstanceType': 'ml.m4.xlarge'
}
]
)
def create_end_point(self):
self.sagemaker_client.create_endpoint(
EndpointName="your_endpoint_name",
EndpointConfigName="your_endpoint_config_name"
)
def wait_end_point(self, max_iter=120):
waiter = self.sagemaker_client.get_waiter("endpoint_in_service")
logging.info("polling start")
waiter.wait(
EndpointName="your_endpoint_name",
WaiterConfig={"MaxAttempts": max_iter}
)
logging.info("polling end")
res = self.sagemaker_client.describe_endpoint(EndpointName="your_endpoint_name")
status = res['EndpointStatus']
if status != 'InService':
message = self.sagemaker_client.describe_endpoint(EndpointName="your_endpoint_name")['FailureReason']
print('Training failed with the following error: {}'.format(message))
raise Exception('Endpoint creation did not succeed')
return status
if __name__ == '__main__':
EndPointCreator().execute()
AWS SageMakerのエンドポイントへアクセスするためのlambda関数を作成する
以下、赤枠の様にlambdaからSageMakerのエンドポイントへアクセスし予測結果を受け取るためのlambda関数を作成します。
以下のコードでは、boto3のinvoke_endpoint
メソッドを使って入力データと供にエンドポイントへアクセスし予測結果を受け取ります。(入力データの渡し方があまりイケテいない気がする...)
import logging
import boto3
LOGGER = logging.getLogger()
LOGGER.setLevel(logging.INFO)
ENDPOINT_NAME = "sagemaker-dev-endpoint"
def lambda_handler(event, _context):
client = boto3.client("sagemaker-runtime", region_name="ap-northeast-1")
values = list(event.values())
# sagemakerのエンドポイントにアクセスし予測結果を受け取る
response = client.invoke_endpoint(
EndpointName=ENDPOINT_NAME,
Body='{0}, {1}, {2}, {3}'.format(values[0], values[1], values[2], values[3]),
ContentType='text/csv',
Accept='application/json'
)
result = response['Body'].read().decode()
return {
'statusCode': 200,
'body': result
}
アプリケーションのフロントエンドからリクエストを受け取るためのAPI Gatewayを作成する
次項で説明するFlaskで作成したフロントエンドからリクエストを受け取るためのAPI Gatewayを作成します。
また、このAPI Gatewayは上記で作成したlambda関数と紐付けてeventを送信する様にします。
コンソール画面からボタンをポチポチするだけで作成できるのですが、再現性を担保すること / 簡単に削除できることを踏まえてAWS Cloudformationを使って構築します。
(コンソール画面から作成したい方は、こちらのサイトが非常に参考になると思います)
API-Keyだけはコンソール画面から取得しておきましょう。
また、以下のコードは上記で作成したlambda関数も一緒に構築しています。
---
# define include macro
AWSTemplateFormatVersion: 2010-09-09
Description: API GateWay environment
# =======set parameters======== #
Parameters:
FunctionName:
Type: String
Description: dev-api
Default: "SageMaker-API"
Runtime:
Description: Language of scripts
Type: String
Default: python3.6
Resources:
# =======IAM======== #
InvokeSageMakerLambdaRole:
Type: AWS::IAM::Role
Properties:
AssumeRolePolicyDocument:
Version: "2012-10-17"
Statement:
- Effect: Allow
Principal:
Service:
- lambda.amazonaws.com
Action:
- sts:AssumeRole
ManagedPolicyArns:
- arn:aws:iam::aws:policy/AWSLambdaFullAccess
- arn:aws:iam::aws:policy/AmazonSageMakerFullAccess
Path: "/service-role/"
# =======lambda======== #
InvokeSageMakerLambda:
Type: AWS::Lambda::Function
Properties:
Code:
ZipFile: !Sub |
import logging
import boto3
LOGGER = logging.getLogger()
LOGGER.setLevel(logging.INFO)
ENDPOINT_NAME = "sagemaker-dev-endpoint"
def lambda_handler(event, _context):
client = boto3.client("sagemaker-runtime", region_name="ap-northeast-1")
values = list(event.values())
response = client.invoke_endpoint(
EndpointName=ENDPOINT_NAME,
Body='{0}, {1}, {2}, {3}'.format(values[0], values[1], values[2], values[3]),
ContentType='text/csv',
Accept='application/json'
)
result = response['Body'].read().decode()
return {
'statusCode': 200,
'body': result
}
Description: "Sagemaker API Invoke"
FunctionName: !Sub "${FunctionName}"
Handler: index.lambda_handler
MemorySize: 128
Role: !GetAtt InvokeSageMakerLambdaRole.Arn
Runtime: !Ref Runtime
Timeout: 15
# =======API Gateway======== #
SageMakerApi:
Type: AWS::ApiGateway::RestApi
Properties:
Name: "SageMakerApi"
Resource:
Type: AWS::ApiGateway::Resource
Properties:
RestApiId: !Ref SageMakerApi
ParentId: !GetAtt SageMakerApi.RootResourceId
PathPart: !Sub "${FunctionName}"
DependsOn: "InvokeSageMakerLambda"
LambdaPermission:
Type: AWS::Lambda::Permission
Properties:
FunctionName: !Sub "${FunctionName}"
Action: "lambda:InvokeFunction"
Principal: "apigateway.amazonaws.com"
DependsOn: "InvokeSageMakerLambda"
ResourceMethod:
Type: AWS::ApiGateway::Method
Properties:
RestApiId: !Ref SageMakerApi
ResourceId: !Ref Resource
AuthorizationType: "None"
HttpMethod: "POST"
Integration:
Type: "AWS"
IntegrationHttpMethod: "POST"
Uri: !Sub "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/arn:aws:lambda:${AWS::Region}:${AWS::AccountId}:function:${FunctionName}/invocations"
IntegrationResponses:
- StatusCode: 200
MethodResponses:
- StatusCode: 200
DependsOn: "LambdaPermission"
ApiGatewayDeployment:
Type: AWS::ApiGateway::Deployment
DependsOn: "ResourceMethod"
Properties:
RestApiId: !Ref SageMakerApi
StageName: "dev"
アプリケーションのフロントエンドをFlaskで作成する
予測結果を返すためのインプットデータを入力するフォームと予測結果を出力する画面をFlaskと簡単なHTMLを使って作成します。(フォームから入力された値を受け取り、API Gatewayにリクエストをします。そして、そのレスポンスを出力します)
入力フォームは、以下の様な簡単なものを作成します。各フォームに数値を入力してsubmitボタンを押下すると、分類結果が画面に表示される様にします。
アプリケーションに必要なファイルとディレクトリ構成は以下の様になっています。
.
├── apps.py
│
├── lib
│ ├── api-key
│ │ └── api-gateway.yml
│ └── categorical_classifier.py
│
└── templates
├── index.html
└── prediction.html
templates/index.html
は、入力フォームの画面を表示するhtmlファイルになります。また、templates/prediction.html
は予測結果を表示するhtmlファイルになります。
<!DOCTYPE html>
<html lang="ja">
<head>
<meta charset="UTF-8">
<title>demo-app</title>
</head>
<body>
<form action="/" method="POST">
<div style=font-size:15px; font-weight:bold; mergin-left:150px;>
<p>Sepal Length<br><input type="text" name="sepal_length"></p>
<p>Sepal Width<br><input type="text" name="sepal_width"></p>
<p>Petal Length<br><input type="text" name="petal_length"></p>
<p>Petal Width<br><input type="text" name="petal_width"></p>
<p><input type="submit" value="submit"></p>
</div>
</form>
</body>
<html>
<!DOCTYPE html>
<html lang="ja">
<head>
<meta charset="UTF-8">
<title>ようこそ</title>
</head>
<body>
<div style=font-size:15px; font-weight:bold; mergin-left:150px;>
{% if result %}
<p>{{ result }}</p>
{% else %}
<p>Hello World</p>
{% endif %}
</div>
</body>
</html>
上記のhtmlファイルに値をレンダリングしたり、入力された値を受け取るためのファイルがapps.py
になります。
やっていることは簡単で
-
http://localhost:5000/
へGET
でアクセスした場合は、入力フォームがあるindex.html
を表示します。 - 入力フォームに値を入れて
submit
ボタンを押下しPOST
した場合は、API-Gatewayにリクエストを送り予測結果を受け取ります。そして、その値をprediction.html
のresult
にレンダリングし表示します。
from flask import Flask, render_template, request
from lib.categorical_classifier import categorical_classifier
app = Flask(__name__)
@app.route("/", methods=["GET", "POST"])
def index():
# 入力フォームに値を入れ、submitボタンを押下しpostした時の処理
if request.method == "POST":
sepal_length = float(request.form["sepal_length"])
sepal_width = float(request.form["sepal_width"])
petal_length = float(request.form["petal_length"])
petal_width = float(request.form["petal_width"])
data = {
"sepal_length": sepal_length,
"sepal_width": sepal_width,
"petal_length": petal_length,
"petal_width": petal_width
}
# 予測結果を受け取る処理はメソッド化しています。
responses = categorical_classifier(data)
result = responses["body"]
return render_template("prediction.html", result=result)
elif request.method == "GET":
return render_template("index.html")
if __name__ == '__main__':
app.debug = True
app.run(host='0.0.0.0')
予測結果を返す処理はcategorical_classifier.py
でメソッド化をしております。
フォームから入力された値を元にAPI-Gatewayで設定されたURLへポストしているだけです。
その際、API-Gatewayで作成したAPI-Keyをapi-key/api-gateway.yml
から取得しheaderとして渡す必要があります。
'x-api-key': 'your-api-key'
import os
import requests
import json
import yaml
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
API_KEY_FILE = os.path.join(BASE_DIR, "./api-key/api-gateway.yml")
API_URL = 'https://your_api_url'
def categorical_classifier(data: dict) -> dict:
"""
:param data={
"sepal length": ,
"sepla width" : ,
"petal length" : ,
"petal width" :
}
:return: {'statusCode': , 'body':}
"""
with open(API_KEY_FILE, mode="r") as file:
header = yaml.load(file)
response = requests.post(API_URL, headers=header, data=json.dumps(data)).json()
return response
おわりに
AWS SageMakerのAPIを使ったアプリケーションとしては、鉄板なアーキテクチャーな気がします。すごく簡単に作成することができました。
フロントエンドのHTMLはもう少し綺麗に書く必要がありますし入力フォームのバリデーションもやらないといけないですが、今回はモックということであまり気にせず作成しました。
今回作成したアプリケーションのコードはgithubにアップしておりますのでそちらも参照して頂ければと思います。