117
105

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

AWS SageMakerとAPI Gateway、Lambdaで作るサーバレスな機械学習アプリケーション

Posted at

本記事でやること

  1. 学習したモデルに対してのエンドポイントを作成する。
  2. AWS SageMakerのエンドポイントへアクセスするためのlambda関数を作成する。
  3. アプリケーションのフロントエンドからリクエストを受け取るためのAPI Gatewayを作成する。(上記で作成したlambda関数を呼びだします)
  4. アプリケーションのフロントエンドをFlaskで作成する。

今回使用するAWSのサービスとアーキテクチャは以下のようにになります。

スクリーンショット 2019-02-14 10.07.39.png

githubに今回書いたコードをあげておりますので適宜参照していただければと思います。READMEにも手順を記載しています。

本記事でやらないことは以下の通りなので、他の記事を参照してください。

  • 前の記事で行なったSageMakerトレーニングジョブの実行方法
  • 前の記事で行なったモデルのデプロイ方法
  • 各種サービスのIAMの設定

対象読者

  • SageMakerを使った簡単な機械学習アプリケーションを作ってみたい方
  • API-Gateway / Lambdaを少し触ったことのある方

使用言語

  • Python 3.6.3
  • (簡単なHTML)

学習したモデルに対してのエンドポイントを作成する

前回の記事を参考にしてモデルのデプロイまでは行なっていることを前提にして進めます。(以下の画面の様にモデルが作成されて入ればokです)

スクリーンショット 2019-02-15 9.07.09.png

以下のコードでは、エンドポイントを作成する前にエンドポイントの構成も作成しております。
公式のページにも記載してありますが、エンドポイントの構成を作成する時には、使用するモデル名とホストするコンピューティングインスタンスを指定します。
また順序として、「エンドポイントの構成の作成 -> エンドポイントの作成」となります。

また、エンドポイントの作成までには時間がかかるのでwaiterを作成してステータスを監視します。

import logging
from boto3.session import Session


class EndPointCreator:
    def __init__(self):
        self.sagemaker_client = Session(profile_name="your_profile").\
            client("sagemaker", region_name="ap-northeast-1")

    def execute(self):
        model_name = self.get_model_name()
        self.create_end_point_config(model_name)
        self.create_end_point()
        self.wait_end_point(max_iter=120)

    def get_model_name(self):
        model_name = self.sagemaker_client.list_models(
            NameContains="your_model_name",
            SortOrder='Descending',
            SortBy='CreationTime'
        )["Models"][0]["ModelName"]

        return model_name

    def create_end_point_config(self, model_name):
        self.sagemaker_client.create_endpoint_config(
            EndpointConfigName="your_endpoint_config_name",
            ProductionVariants=[
                {
                    'VariantName': 'hoge',
                    'ModelName': model_name,
                    'InitialInstanceCount': 1,
                    'InstanceType': 'ml.m4.xlarge'
                }
            ]
        )

    def create_end_point(self):
        self.sagemaker_client.create_endpoint(
            EndpointName="your_endpoint_name",
            EndpointConfigName="your_endpoint_config_name"
        )

    def wait_end_point(self, max_iter=120):
        waiter = self.sagemaker_client.get_waiter("endpoint_in_service")
        logging.info("polling start")

        waiter.wait(
            EndpointName="your_endpoint_name",
            WaiterConfig={"MaxAttempts": max_iter}
        )
        logging.info("polling end")

        res = self.sagemaker_client.describe_endpoint(EndpointName="your_endpoint_name")
        status = res['EndpointStatus']

        if status != 'InService':
            message = self.sagemaker_client.describe_endpoint(EndpointName="your_endpoint_name")['FailureReason']
            print('Training failed with the following error: {}'.format(message))
            raise Exception('Endpoint creation did not succeed')

        return status


if __name__ == '__main__':
    EndPointCreator().execute()

AWS SageMakerのエンドポイントへアクセスするためのlambda関数を作成する

以下、赤枠の様にlambdaからSageMakerのエンドポイントへアクセスし予測結果を受け取るためのlambda関数を作成します。

スクリーンショット 2019-02-14 10.07.39.png

以下のコードでは、boto3のinvoke_endpointメソッドを使って入力データと供にエンドポイントへアクセスし予測結果を受け取ります。(入力データの渡し方があまりイケテいない気がする...)

lambda_function.py

import logging
import boto3

LOGGER = logging.getLogger()
LOGGER.setLevel(logging.INFO)

ENDPOINT_NAME = "sagemaker-dev-endpoint"


def lambda_handler(event, _context):
    client = boto3.client("sagemaker-runtime", region_name="ap-northeast-1")
    values = list(event.values())
    
    # sagemakerのエンドポイントにアクセスし予測結果を受け取る
    response = client.invoke_endpoint(
        EndpointName=ENDPOINT_NAME,
        Body='{0}, {1}, {2}, {3}'.format(values[0], values[1], values[2], values[3]),
        ContentType='text/csv',
        Accept='application/json'
    )

    result = response['Body'].read().decode()

    return {
        'statusCode': 200,
        'body': result
    }

アプリケーションのフロントエンドからリクエストを受け取るためのAPI Gatewayを作成する

次項で説明するFlaskで作成したフロントエンドからリクエストを受け取るためのAPI Gatewayを作成します。
また、このAPI Gatewayは上記で作成したlambda関数と紐付けてeventを送信する様にします。

スクリーンショット 2019-02-14 10.07.39.png

コンソール画面からボタンをポチポチするだけで作成できるのですが、再現性を担保すること / 簡単に削除できることを踏まえてAWS Cloudformationを使って構築します。
(コンソール画面から作成したい方は、こちらのサイトが非常に参考になると思います)
API-Keyだけはコンソール画面から取得しておきましょう。

また、以下のコードは上記で作成したlambda関数も一緒に構築しています。

cloudformation.yml
---
# define include macro


AWSTemplateFormatVersion: 2010-09-09
Description: API GateWay environment
# =======set parameters======== #
Parameters:
  FunctionName:
    Type: String
    Description: dev-api
    Default: "SageMaker-API"
  Runtime:
    Description: Language of scripts
    Type: String
    Default: python3.6

Resources:
  # =======IAM======== #
  InvokeSageMakerLambdaRole:
    Type: AWS::IAM::Role
    Properties:
      AssumeRolePolicyDocument:
        Version: "2012-10-17"
        Statement:
          - Effect: Allow
            Principal:
              Service:
                - lambda.amazonaws.com
            Action:
              - sts:AssumeRole
      ManagedPolicyArns:
        - arn:aws:iam::aws:policy/AWSLambdaFullAccess
        - arn:aws:iam::aws:policy/AmazonSageMakerFullAccess
      Path: "/service-role/"

  # =======lambda======== #
  InvokeSageMakerLambda:
    Type: AWS::Lambda::Function
    Properties:
      Code:
        ZipFile: !Sub |
          import logging
          import boto3

          LOGGER = logging.getLogger()
          LOGGER.setLevel(logging.INFO)

          ENDPOINT_NAME = "sagemaker-dev-endpoint"


          def lambda_handler(event, _context):
              client = boto3.client("sagemaker-runtime", region_name="ap-northeast-1")
              values = list(event.values())

              response = client.invoke_endpoint(
                  EndpointName=ENDPOINT_NAME,
                  Body='{0}, {1}, {2}, {3}'.format(values[0], values[1], values[2], values[3]),
                  ContentType='text/csv',
                  Accept='application/json'
              )

              result = response['Body'].read().decode()

              return {
                  'statusCode': 200,
                  'body': result
              }
      Description: "Sagemaker API Invoke"
      FunctionName: !Sub "${FunctionName}"
      Handler: index.lambda_handler
      MemorySize: 128
      Role: !GetAtt InvokeSageMakerLambdaRole.Arn
      Runtime: !Ref Runtime
      Timeout: 15

  # =======API Gateway======== #
  SageMakerApi:
    Type: AWS::ApiGateway::RestApi
    Properties:
      Name: "SageMakerApi"
  Resource:
    Type: AWS::ApiGateway::Resource
    Properties:
      RestApiId: !Ref SageMakerApi
      ParentId: !GetAtt SageMakerApi.RootResourceId
      PathPart: !Sub "${FunctionName}"
    DependsOn: "InvokeSageMakerLambda"
  LambdaPermission:
    Type: AWS::Lambda::Permission
    Properties:
      FunctionName: !Sub "${FunctionName}"
      Action: "lambda:InvokeFunction"
      Principal: "apigateway.amazonaws.com"
    DependsOn: "InvokeSageMakerLambda"
  ResourceMethod:
    Type: AWS::ApiGateway::Method
    Properties:
      RestApiId: !Ref SageMakerApi
      ResourceId: !Ref Resource
      AuthorizationType: "None"
      HttpMethod: "POST"
      Integration:
        Type: "AWS"
        IntegrationHttpMethod: "POST"
        Uri: !Sub "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/arn:aws:lambda:${AWS::Region}:${AWS::AccountId}:function:${FunctionName}/invocations"
        IntegrationResponses:
        - StatusCode: 200
      MethodResponses:
      - StatusCode: 200
    DependsOn: "LambdaPermission"

  ApiGatewayDeployment:
    Type: AWS::ApiGateway::Deployment
    DependsOn: "ResourceMethod"
    Properties:
      RestApiId: !Ref SageMakerApi
      StageName: "dev"

アプリケーションのフロントエンドをFlaskで作成する

予測結果を返すためのインプットデータを入力するフォームと予測結果を出力する画面をFlaskと簡単なHTMLを使って作成します。(フォームから入力された値を受け取り、API Gatewayにリクエストをします。そして、そのレスポンスを出力します)

スクリーンショット 2019-02-14 10.07.39.png

入力フォームは、以下の様な簡単なものを作成します。各フォームに数値を入力してsubmitボタンを押下すると、分類結果が画面に表示される様にします。

スクリーンショット 2019-02-16 18.06.21.png スクリーンショット 2019-02-16 18.45.56.png

アプリケーションに必要なファイルとディレクトリ構成は以下の様になっています。

.
├── apps.py
│
├── lib
│   ├── api-key
│   │   └── api-gateway.yml
│   └── categorical_classifier.py
│
└── templates
    ├── index.html
    └── prediction.html

templates/index.htmlは、入力フォームの画面を表示するhtmlファイルになります。また、templates/prediction.htmlは予測結果を表示するhtmlファイルになります。

index.html
<!DOCTYPE html>
<html lang="ja">
    <head>
        <meta charset="UTF-8">
        <title>demo-app</title>
    </head>
    <body>
        <form action="/" method="POST">
            <div style=font-size:15px; font-weight:bold; mergin-left:150px;>
                <p>Sepal Length<br><input type="text" name="sepal_length"></p>
                <p>Sepal Width<br><input type="text" name="sepal_width"></p>
                <p>Petal Length<br><input type="text" name="petal_length"></p>
                <p>Petal Width<br><input type="text" name="petal_width"></p>
                <p><input type="submit" value="submit"></p>
            </div>
        </form>
    </body>
<html>

prediction.html
<!DOCTYPE html>
<html lang="ja">
<head>
    <meta charset="UTF-8">
    <title>ようこそ</title>
</head>
<body>
    <div style=font-size:15px; font-weight:bold; mergin-left:150px;>
        {% if result %}
            <p>{{ result }}</p>
        {% else %}
            <p>Hello World</p>
        {% endif %}
    </div>
</body>
</html>

上記のhtmlファイルに値をレンダリングしたり、入力された値を受け取るためのファイルがapps.pyになります。

やっていることは簡単で

  • http://localhost:5000/GETでアクセスした場合は、入力フォームがあるindex.htmlを表示します。
  • 入力フォームに値を入れてsubmitボタンを押下しPOSTした場合は、API-Gatewayにリクエストを送り予測結果を受け取ります。そして、その値をprediction.htmlresultにレンダリングし表示します。
apps.py

from flask import Flask, render_template, request

from lib.categorical_classifier import categorical_classifier

app = Flask(__name__)


@app.route("/", methods=["GET", "POST"])
def index():
    # 入力フォームに値を入れ、submitボタンを押下しpostした時の処理
    if request.method == "POST":
        sepal_length = float(request.form["sepal_length"])
        sepal_width = float(request.form["sepal_width"])
        petal_length = float(request.form["petal_length"])
        petal_width = float(request.form["petal_width"])

        data = {
            "sepal_length": sepal_length,
            "sepal_width": sepal_width,
            "petal_length": petal_length,
            "petal_width": petal_width
        }
        # 予測結果を受け取る処理はメソッド化しています。
        responses = categorical_classifier(data)
        result = responses["body"]

        return render_template("prediction.html", result=result)

    elif request.method == "GET":
        return render_template("index.html")


if __name__ == '__main__':
    app.debug = True
    app.run(host='0.0.0.0')

予測結果を返す処理はcategorical_classifier.pyでメソッド化をしております。
フォームから入力された値を元にAPI-Gatewayで設定されたURLへポストしているだけです。
その際、API-Gatewayで作成したAPI-Keyをapi-key/api-gateway.ymlから取得しheaderとして渡す必要があります。

api-gateway.yml
'x-api-key': 'your-api-key'
categorical_classifier.py
import os
import requests
import json
import yaml

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
API_KEY_FILE = os.path.join(BASE_DIR, "./api-key/api-gateway.yml")
API_URL = 'https://your_api_url'

def categorical_classifier(data: dict) -> dict:
    """
    :param data={
    "sepal length": ,
    "sepla width" : ,
    "petal length" : ,
    "petal width" :
    }

    :return: {'statusCode': , 'body':}
    """
    with open(API_KEY_FILE, mode="r") as file:
        header = yaml.load(file)

    response = requests.post(API_URL, headers=header, data=json.dumps(data)).json()

    return response

おわりに

AWS SageMakerのAPIを使ったアプリケーションとしては、鉄板なアーキテクチャーな気がします。すごく簡単に作成することができました。
フロントエンドのHTMLはもう少し綺麗に書く必要がありますし入力フォームのバリデーションもやらないといけないですが、今回はモックということであまり気にせず作成しました。

今回作成したアプリケーションのコードはgithubにアップしておりますのでそちらも参照して頂ければと思います。

117
105
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
117
105

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?