LoginSignup
3

More than 1 year has passed since last update.

posted at

updated at

Organization

GPUを用いる機械学習推論処理をAWS BatchのSpotインスタンスで実現する

概要

AWSにて機械学習を行うにはまずSageMakerがあります。SageMakerはモデルの学習・管理・デプロイを取りまとめて行えるフルマネージドサービスです。
概ね便利なサービスなのですが、機械学習処理で用いるGPUインスタンスは高額なのでスポットインスタンスでコスト軽減を行いたくなります。
モデルの学習にはスポットインスタンスが使えるのですが、バッチ変換ジョブにはスポットインスタンスが使えません。(2019/12現在)
https://docs.aws.amazon.com/ja_jp/sagemaker/latest/dg/model-managed-spot-training.html
https://docs.aws.amazon.com/ja_jp/sagemaker/latest/dg/batch-transform.html

バッチとして非リアルタイムで推論するGPU処理をAWS Batchを使って実装したのでメモとして残します。

ポイント

AWS Batch

AWS Batchとは名前の通り、バッチ処理を大量のリソースで並列に処理できるサービスです
https://docs.aws.amazon.com/ja_jp/batch/latest/userguide/what-is-batch.html

「コンピューティング環境」というのでマシンリソース量の上限・下限を定義します。
リソースの調整はAWS Batchが仕事量に応じて適切に行ってくれます
下限を0にしておけば、使ってるときだけマシンを起動させることができます
スポットインスタンスの利用が可能で、CPUインスタンス・GPUインスタンスを使い分けられます

後述しますが、「配列ジョブ」という機能を使うのがポイントで、これでジョブを分割して並列に実行できます
https://docs.aws.amazon.com/ja_jp/batch/latest/userguide/array_jobs.html

MLOpsフレームワークであるmetaflowなんかはAWS Batch用の機能がありますね
https://docs.metaflow.org/metaflow/scaling#using-aws-batch

ECSイメージの準備

AWS BatchからはECSが呼ばれるため、EC2インスタンスには制御用のECSエージェントがインストールされている必要があります。
またGPUを利用するためにnvidiaドライバも必要です。
双方を満たすためにec2のGPU用AMIを利用します
https://docs.aws.amazon.com/ja_jp/AmazonECS/latest/developerguide/ecs-optimized_AMI.html
https://docs.aws.amazon.com/ja_jp/AmazonECS/latest/developerguide/retrieve-ecs-optimized_AMI.html

  ECSAMI:
    Description: AMI ID
    Type: AWS::SSM::Parameter::Value<AWS::EC2::Image::Id>
    Default: /aws/service/ecs/optimized-ami/amazon-linux-2/gpu/recommended/image_id

ランタイムの指定

dockerコンテナ内からnvidiaのデバイスを使いたいのであれば、dockerの起動オプションにruntime=nvidiaを指定する必要があります。
ECSで実現する場合、自作のイメージを作成して設定ファイルを書き換える、ec2のUserDataから設定ファイルを書き換えるなどの方法があります。
AWS Batchで簡易に実現する方法としてJobDefinitionの制約にGPUを指定すると自動的にruntimeがnvidiaになります
しかしながら、この方法ではリソース制約としてGPU1つを要求するため、GPUが1つだけのマシンであれば1インスタンスで1ジョブしか動かせません。スクリプト内でリソースを使い切れるように並列処理を行うことになると思います。

        ResourceRequirements:
          - Type: GPU
            Value: 1

スポットインスタンス停止対策

スポットインスタンスを利用するのでインスタンスが強制停止される可能性があります。
強制停止される通知を受け取った後に途中結果をどこかに保存して、途中から再開できるような推論処理にするのが最も効率が良いです。
ただ今回は推論処理に手を入れたくなかったため、AWSバッチの配列ジョブにて処理を細かいジョブに分けた上で正常終了しなかったジョブはAWS Batchのリトライ機構でリトライさせます。

 配列ジョブ

AWS Batch側でジョブに数字を割り振り、スクリプトからはその数字を環境変数から読み込み、ジョブごとに別の処理をさせることができます。
細かくジョブを分けることでジョブ単位でリトライができるのと、ジョブ単位で並列処理ができます。
ジョブの粒度を小さくしすぎるとコンテナのイメージダウンロードやECS処理などのオーバーヘッドが大きくなるため、いい塩梅で設定してください。
https://docs.aws.amazon.com/ja_jp/batch/latest/userguide/array_jobs.html

pythonスクリプトからos.getenv('AWS_BATCH_JOB_ARRAY_INDEX')のような形で読めます

リトライ

念の為上限の10回リトライさせています

      RetryStrategy:
        Attempts: 10 # 1-10 allowed

aws batchの呼び出し

airflowからこんな感じのスクリプトでboto3を呼び出しています
arrayPropertiesとして配列ジョブのサイズを指定しています。

def _submit_job(*, target_ymd, array_size):
    batch = session(env).client('batch', region_name=region)
    response = batch.submit_job(
        jobName="",
        jobQueue="",
        jobDefinition="",
        arrayProperties={"size": array_size},
        containerOverrides={
            "command": [
                "python",
                "script.py",
                "--target_date",
                target_ymd,
            ],
            "environment": [
                {
                    "name": "STAGE",
                    "value": "test"
                }
            ]
        },
    )
    job_id = response['jobId']
    print(f'job has been sent. job id: {job_id}')
    return job_id


def _wait_job_finish(job_id):
    batch = session(env).client('batch', region_name=region)
    while True:
        response = batch.describe_jobs(jobs=[job_id])
        status = response['jobs'][0]['status']
        if status in ['SUCCEEDED']:
            return True
        elif status in ['FAILED']:
            return False
        else:
            sleep(10)

呼び出しのlambdaをserverlessから構築してる人もいるようです
https://blog.ikedaosushi.com/entry/2019/04/27/222957

region

作業用インスタンスはネットワークの遅延が小さい東京リージョンを使いがちだと思いますが、バッチジョブにおいては多少ネットワークが遅くても問題ないので最安リージョンで実行できます。
最安リージョンとしては北カリフォルニア以外のアメリカ3リージョン(バージニア、オハイオ、オレゴン)を抑えておけば良さそうです。
リージョン差は結構大きくて、例えばp2.xlargeにおいては0.90$/h(us-east-1)と1.542$/h(ap-northeast-1)と1.6倍ほどの開きがあります。

CloudFormation

CloudFormationにて環境構築をコード化します。下記を参考にしました

AWS CloudFormation のベストプラクティス
https://docs.aws.amazon.com/ja_jp/AWSCloudFormation/latest/UserGuide/best-practices.html

【AWS】CloudFormationの実践的活用~個人的ベストプラクティス
https://qiita.com/tmiki/items/022df525defa82d1785b

CloudFormation Templateの分け方

今回はVPCから環境を構築しました。VPCは別アプリケーションでも使い回す想定で構築するので、AWS Batchのバッチジョブ層とはライフサイクルが異なります。また、ネットワーク周りの環境依存情報(VPC Peeringなど)はわかりやすいよう別にしておきたかったので中間層として切り出しました。合計3つのテンプレートを作ります。

Cloud Formation Template

VPC

2AZ, 3AZ, 4AZのテンプレートは下記サイトにあります。(一番上3つです)
https://templates.cloudonaut.io/en/stable/vpc/
NCALでは制限をかけず、SGでセキュリティを担保する思想のテンプレートなので、NCALを変えたければ改変する必要があります

VPCに関しては繰り返し処理が多くなりがちですのでいろいろいじるのであればcdkで記述する手もありそうです。
https://github.com/aws/aws-cdk

中間層

別リージョンの踏み台サーバからアクセスできるようにする想定で、環境依存の設定を切り出します。
セキュリティグループの設定とVPCピアリングの設定があります。
(ちなみにsshできるようにするのはデバッグ用で、AWS Batchで使う分にはIAMの設定だけで大丈夫です。)
別途踏み台サーバ側のルートテーブルとSG, NACLの設定も必要です

# this template is supposed to use after make vpc layer
---
AWSTemplateFormatVersion: 2010-09-09
Description: Middle Cloud Formation layer depending on environment specific value

Parameters:
  VpcStackName:
    Type: String
  BastionServerPrivateIP:
    Type: String
  BastionServerVpcId:
    Type: String
  BastionServerRegion:
    Type: String

Resources:
  #####
  # Security Group
  #####
  ComputeSecurityGroup:
    Type: AWS::EC2::SecurityGroup
    Properties:
      VpcId:
        Fn::ImportValue: !Sub ${VpcStackName}-VPC
      GroupDescription: Enable SSH access via port 22 from bastion server
      SecurityGroupIngress:
        - CidrIp: !Sub ${BastionServerPrivateIP}/32
          IpProtocol: tcp
          FromPort: 22
          ToPort: 22
      Tags:
        - Key: Name
          Value: !Sub ${AWS::StackName}-compute
  #####
  # VPC Peering
  #####
  VPCPeering:
    Type: AWS::EC2::VPCPeeringConnection
    Properties:
      PeerRegion: !Ref BastionServerRegion
      PeerVpcId: !Ref BastionServerVpcId
      VpcId:
        Fn::ImportValue: !Sub ${VpcStackName}-VPC
      Tags:
        -
          Key: Name
          Value: bastion peering
        -
          Key: Desc
          Value: vpc peering for access via bastion server
  RouteTablePublicInternetRouteA:
    Type: 'AWS::EC2::Route'
    Properties:
      RouteTableId:
        Fn::ImportValue: !Sub ${VpcStackName}-RouteTableAPublic
      DestinationCidrBlock: !Sub ${BastionServerPrivateIP}/32
      VpcPeeringConnectionId: !Ref VPCPeering
  RouteTablePublicInternetRouteB:
    Type: 'AWS::EC2::Route'
    Properties:
      RouteTableId:
        Fn::ImportValue: !Sub ${VpcStackName}-RouteTableBPublic
      DestinationCidrBlock: !Sub ${BastionServerPrivateIP}/32
      VpcPeeringConnectionId: !Ref VPCPeering
  RouteTablePublicInternetRouteC:
    Type: 'AWS::EC2::Route'
    Properties:
      RouteTableId:
        Fn::ImportValue: !Sub ${VpcStackName}-RouteTableCPublic
      DestinationCidrBlock: !Sub ${BastionServerPrivateIP}/32
      VpcPeeringConnectionId: !Ref VPCPeering
  RouteTablePublicInternetRouteD:
    Type: 'AWS::EC2::Route'
    Properties:
      RouteTableId:
        Fn::ImportValue: !Sub ${VpcStackName}-RouteTableDPublic
      DestinationCidrBlock: !Sub ${BastionServerPrivateIP}/32
      VpcPeeringConnectionId: !Ref VPCPeering

Outputs:
  ComputeSecurityGroup:
    Description: security group for general compute instance
    Value: !Ref ComputeSecurityGroup
    Export:
      Name: !Sub ${AWS::StackName}-ComputeSecurityGroup

バッチジョブ層

主にIAMの設定とAWS Batchの設定を書いています
IAMには
* AWS Batchの自動制御を許可するrole AWSBatchServiceRole
* ec2インスタンスにecsの自動制御を許可するrole ecsInstanceRole
* spot fleetの自動制御を許可するrole AmazonEC2SpotFleetTaggingRole
* ecs内のコンテナにawsサービスの利用を許可するrole JobRole
などを設定する必要があります。

AWS Batchの環境構築にはComputeEnvironment, JobDefinition, JobQueueの3層が必要なので設定します。
ECSの利用にECRのリポジトリが必要なのでついでに構築しています。開発環境と本番環境でアカウントが異なる場合でも、それぞれのアカウントにリポジトリを作成する想定です。

AWSTemplateFormatVersion: 2010-09-09
Description: Build AWS Batch environment

Parameters:
  VpcStackName:
    Type: String
  MiddleStackName:
    Type: String
  Ec2KeyPair:
    Type: String
  EcrPushArn:
    Type: String
  EcrRepositoryName:
    Type: String
  ECSAMI:
    Description: AMI ID
    Type: AWS::SSM::Parameter::Value<AWS::EC2::Image::Id>
    Default: /aws/service/ecs/optimized-ami/amazon-linux-2/gpu/recommended/image_id
  EnvType:
    Description: Environment type.
    Default: test
    Type: String
    AllowedValues:
      - prod
      - test
    ConstraintDescription: must specify prod or test.
Conditions:
  IsProd: !Equals [ !Ref EnvType, prod ]

Resources:
  #####
  # IAM
  #####
  AWSBatchServiceRole:
    Type: AWS::IAM::Role
    Properties:
      AssumeRolePolicyDocument:
        Version: '2012-10-17'
        Statement:
          - Effect: Allow
            Principal:
              Service:
                - batch.amazonaws.com
            Action:
              - sts:AssumeRole
      ManagedPolicyArns:
        - arn:aws:iam::aws:policy/service-role/AWSBatchServiceRole
      Path: "/service-role/"
  ecsInstanceRole:
    Type: AWS::IAM::Role
    Properties:
      AssumeRolePolicyDocument:
        Version: '2012-10-17'
        Statement:
          - Effect: Allow
            Principal:
              Service:
                - ec2.amazonaws.com
            Action:
              - sts:AssumeRole
      ManagedPolicyArns:
        - arn:aws:iam::aws:policy/service-role/AmazonEC2ContainerServiceforEC2Role
  ecsInstanceProfile:
    Type: "AWS::IAM::InstanceProfile"
    Properties:
      Roles:
        - !Ref ecsInstanceRole
  AmazonEC2SpotFleetTaggingRole:
    Type: AWS::IAM::Role
    Properties:
      AssumeRolePolicyDocument:
        Version: '2012-10-17'
        Statement:
          - Effect: Allow
            Principal:
              Service:
                - spotfleet.amazonaws.com
            Action:
              - sts:AssumeRole
      ManagedPolicyArns:
        - arn:aws:iam::aws:policy/service-role/AmazonEC2SpotFleetTaggingRole
      Path: "/service-role/"
  JobRole:
    Type: AWS::IAM::Role
    Properties:
      AssumeRolePolicyDocument:
        Version: '2012-10-17'
        Statement:
          - Effect: Allow
            Principal:
              Service:
                - ecs-tasks.amazonaws.com
            Action:
              - sts:AssumeRole
      ManagedPolicyArns:
        # set policy your program need
        - arn:aws:iam::aws:policy/AmazonS3FullAccess
        - arn:aws:iam::aws:policy/CloudWatchLogsFullAccess

  #####
  # ECR
  #####
  ECR:
    Type: AWS::ECR::Repository
    Properties:
      RepositoryName: !Ref EcrRepositoryName
      RepositoryPolicyText:
        Version: "2012-10-17"
        Statement:
          - Sid: AllowPushPull
            Effect: Allow
            Principal:
              AWS:
                - !Ref EcrPushArn
            Action:
              - "ecr:GetDownloadUrlForLayer"
              - "ecr:BatchGetImage"
              - "ecr:BatchCheckLayerAvailability"
              - "ecr:PutImage"
              - "ecr:InitiateLayerUpload"
              - "ecr:UploadLayerPart"
              - "ecr:CompleteLayerUpload"

  #####
  # AWS Batch
  #####
  ComputeEnv:
    Type: AWS::Batch::ComputeEnvironment
    Properties:
      # do not specify ComputeEnvironmentName to allow resource update
      Type: MANAGED
      ServiceRole: !GetAtt AWSBatchServiceRole.Arn
      ComputeResources:
        Tags:
          Name: Spot Fleet Managed
          auto_stop: "false"
        Type: SPOT
        MaxvCpus: !If [IsProd, 8, 4] # 本番環境でのインスタンス数*4
        MinvCpus: 0
        InstanceTypes:
          - p2.xlarge
        SecurityGroupIds:
          - Fn::ImportValue: !Sub ${MiddleStackName}-ComputeSecurityGroup
        Subnets:
          Fn::Split:
            - ","
            - Fn::ImportValue: !Sub ${VpcStackName}-SubnetsPublic
        ImageId: !Ref ECSAMI
        Ec2KeyPair: !Ref Ec2KeyPair
        InstanceRole: !GetAtt ecsInstanceProfile.Arn
        SpotIamFleetRole: !GetAtt AmazonEC2SpotFleetTaggingRole.Arn
        BidPercentage: 51
      State: ENABLED
  JobDef:
    Type: AWS::Batch::JobDefinition
    Properties:
      JobDefinitionName: !Sub ${AWS::StackName}-jobdef
      Type: container
      ContainerProperties:
        Image: !Sub ${AWS::AccountId}.dkr.ecr.${AWS::Region}.amazonaws.com/${EcrRepositoryName}:latest
        Vcpus: 4
        Memory: 50000 # MiB
        JobRoleArn: !GetAtt JobRole.Arn
        Privileged: True
        ResourceRequirements:
          - Type: GPU
            Value: 1
      RetryStrategy:
        Attempts: 10 # 1-10 allowed
      Timeout:
        AttemptDurationSeconds: 86400 # 24h
  JobQueue:
    Type: AWS::Batch::JobQueue
    Properties:
      JobQueueName: !Sub ${AWS::StackName}-jobqueue
      ComputeEnvironmentOrder:
        - ComputeEnvironment: !Ref ComputeEnv
          Order: 0
      Priority: 100
      State: ENABLED

CloudFormation起動スクリプト

こんな感じのスクリプトから呼び出します

#!/bin/bash

if [ $# -lt 2 ]; then
  echo 'two argument required. environment, input command'
  exit
fi

env=${1}
if [ ${env} = 'test' ]; then
  profile=ここらへん入れてください
  region=
  vpc_stack_name=
  middle_stack_name=
  application_stack_name=
  bastion_private_ip=
  bastion_vpc_id=
  key_pair_name=
elif [ ${env} = 'prod' ]; then
# 略
else
  echo 'test or prod is required.' >&2
  exit 1
fi


function all () {
  create_key_pair
  create_vpc
  create_middle
  create_batch_env
}

function create_key_pair () {
  # Key pair is created to project root. Move it to suitable place.
  aws ec2 create-key-pair \
    --profile ${profile} \
    --region ${region} \
    --key-name ${key_pair_name} \
    --query 'KeyMaterial' \
    --output text > ${key_pair_name}.pem
}

function create_vpc () {
  aws cloudformation deploy \
    --profile ${profile} \
    --region ${region} \
    --template-file template/networking/vpc-4az.yml \
    --stack-name ${vpc_stack_name} \
    --capabilities CAPABILITY_IAM \
    --parameter-overrides \
      ClassB=0−255までの数字でVPCのIPアドレス指定に使われる \
  ;
}

function create_middle () {
  aws cloudformation deploy \
    --profile ${profile} \
    --region ${region} \
    --template-file template/networking/env_specific.yml \
    --stack-name ${middle_stack_name} \
    --capabilities CAPABILITY_IAM \
    --parameter-overrides \
      VpcStackName=${vpc_stack_name} \
      BastionServerPrivateIP=${bastion_private_ip} \
      BastionServerVpcId=${bastion_vpc_id} \
      BastionServerRegion=ap-northeast-1 \
  ;
}

function create_batch_env () {
  aws cloudformation deploy \
    --profile ${profile} \
    --region ${region} \
    --template-file テンプレート.yml \
    --stack-name ${application_stack_name} \
    --capabilities CAPABILITY_IAM \
    --parameter-overrides \
      VpcStackName=${vpc_stack_name} \
      MiddleStackName=${middle_stack_name} \
      Ec2KeyPair=${key_pair_name} \
      EcrRepositoryName=名前 \
      EcrPushArn=AWSのARN \
      EnvType=${env} \
  ;
}

# execute input command
${2}

まとめ

スポットインスタンスの利用により6-7割引ほどの恩恵を受けられています。コストを削減すればその分を更にたくさんのマシンリソースに充てられます。積極的に使っていきましょう。

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
What you can do with signing up
3