LoginSignup
9
3

More than 3 years have passed since last update.

機械学習プロジェクトのフォーマットを用意してSageMakerで実行する

Last updated at Posted at 2019-12-06

こんにちは、株式会社LIFULLの二宮です。

機械学習のプロジェクトでは、分析やモデルの精度評価でうまくいった後、それをうまく既存のシステムで利用できるようにしなければいけません。その際に、自分たちのチームでは実装担当のエンジニアの役割分担に苦労していました。

「データサイエンティストがこのフォーマットで作ってくれれば簡単に組み込めるよ!」という状態を目指して、Amazon SageMakerをラップして、ある程度汎用的に使える開発フォーマットとツールを用意しました。

Amazon SageMakerとは

Amazon SageMaker は、すべての開発者とデータサイエンティストに機械学習モデルの構築、トレーニング、デプロイ手段を提供します。Amazon SageMaker は、機械学習のワークフロー全体をカバーする完全マネージド型サービスです。データをラベル付けして準備し、アルゴリズムを選択して、モデルのトレーニングを行い、デプロイのための調整と最適化を行い、予測を行い、実行します。モデルをより少ない労力と費用で、本番稼働させることができます。

主な機能としては、特定の仕様に合わせたDockerイメージを用意すれば、以下のような機能を利用できます。

SageMakerで自作のDockerイメージを用意する場合の仕様については、公式ドキュメントや @taniyam さん(私と同じチームです)の記事を読んでください。

機械学習プロジェクトのフォーマット

まず、データサイエンティストには次のようなディレクトリ構成で用意してもらうようにしました。

.
├── README.md
├── Dockerfile
├── config.yml
├── pyproject.toml (poetryの設定ファイル)
├── script
│   └── __init__.py
└── tests
    └── __init__.py

script/__init__.py にメインの処理が書かれていて、以下のようなスクリプトになっています。 simple_sagemaker_manager が今回用意したライブラリです。

import pandas as pd
from typing import List
from pathlib import Path
from sklearn import tree
from simple_sagemaker_manager.image_utils import AbstractModel


def train(training_path: Path) -> AbstractModel:
    """学習を行う。

    Args:
        training_path (Path): csvファイルのあるディレクトリ

    Returns:
        Model: AbstractModelを継承したモデルオブジェクト

    """
    train_data = pd.concat([pd.read_csv(fname, header=None) for fname in training_path.iterdir()])
    train_y = train_data.iloc[:, 0]
    train_X = train_data.iloc[:, 1:] 

    # Now use scikit-learn's decision tree classifier to train the model.
    clf = tree.DecisionTreeClassifier(max_leaf_nodes=None)
    clf = clf.fit(train_X, train_y)
    return Model(clf)


class Model(AbstractModel):
    """AbstractModelにシリアライズの方法などが記述されている。
    """

    def predict(self, matrix: List[List[float]]) -> List[List[str]]:
        """推論処理。

        Args:
            matrix (List[List[float]]): テーブルデータ

        Returns:
            list: 推論結果

        """
        # ここで返した結果が推論APIのレスポンスになります
        return [[x] for x in self.model.predict(pd.DataFrame(matrix))]

AbstractModel は次のような定義で、 save メソッドが呼び出された結果(pickleでシリアライズされた結果)が保存され、学習バッチの実行時は(SageMakerのシステムで利用して)これがモデルとしてS3に保存されます。また、 saveload をオーバーライドすることでシリアライズ方式を切り替えられるようにしています。

import pickle
from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass
class AbstractModel(ABC):
    model: object

    @classmethod
    def load(cls, model_path):
        # 学習バッチ時にモデルを保存する
        with open(model_path / 'model.pkl', 'rb') as f:
            model = pickle.load(f)
        return cls(model)

    def save(self, model_path):
        # 推論時にモデルを読み込む
        with open(model_path / 'model.pkl', 'wb') as f:
            pickle.dump(self.model, f)

    @abstractmethod
    def predict(self, json):
        pass

Pythonのpoetryなどのプロジェクトを参考にして、cliで操作を行うようにしています。SageMakerのDockerイメージの開発フローは以下のようになります。

  • プロジェクトの雛形を作成する ( smcli new プロジェクト名 )
  • 雛形を編集する
  • イメージをビルドする ( smcli build )
  • イメージをECRにpushする ( smcli push )

また、Dockerfileを編集できるようにしたのは、一部の機械学習ライブラリがAnacondaでしかインストールできないため、「Python3の公式イメージ以外にも挿し替えられるようにしてほしい」という要望を受けたためです。

SageMakerの実行管理

boto3 を直接実行するのもきついので、こちらもラップしたライブラリを用意しました。たくさんの操作がありますが、多くのプロジェクトで我々がやりたいことは「モデルを学習する」「推論APIを立てる OR バッチ変換ジョブを実行する」の3つなので、それがわかるようなインターフェイスにしています。

from simple_sagemaker_manager.executor import SageMakerExecutor
from simple_sagemaker_manager.executor.classes import TrainInstance, TrainSpotInstance, Image


client = SageMakerExecutor()

# 通常インスタンスで学習する場合
model = client.execute_batch_training(
    instance=TrainInstance(
        instance_type='ml.m4.xlarge',
        instance_count=1,
        volume_size_in_gb=10,
        max_run=100
    ),
    image=Image(
        name="decision-trees-sample",
        uri="xxxxxxxxxx.dkr.ecr.ap-northeast-1.amazonaws.com/decision-trees-sample:latest"
    ),
    input_path="s3://xxxxxxxxxx/DEMO-scikit-byo-iris",
    output_path="s3://xxxxxxxxxx/output",
    role="arn:aws:iam::xxxxxxxxxx"
)


# スポットインスタンスで学習する場合
model = client.execute_batch_training(
    instance=TrainSpotInstance(
        instance_type='ml.m4.xlarge',
        instance_count=1,
        volume_size_in_gb=10,
        max_run=100,
        max_wait=1000
    ),
    image=Image(
        name="decision-trees-sample",
        uri="xxxxxxxxxx.dkr.ecr.ap-northeast-1.amazonaws.com/decision-trees-sample:latest"
    ),
    input_path="s3://xxxxxxxxxx/DEMO-scikit-byo-iris",
    output_path="s3://xxxxxxxxxxx/output",
    role="arn:aws:iam::xxxxxxxxxxxxx"
)

推論APIは以下のように作っています。工夫した点としては以下の通りです。

  • 指定した名前(name)のエンドポイントが存在しない場合はエンドポイントを新規作成する
  • 存在する場合はupdateを行う。Updatingの間もリクエストは受け付ける。
  • modelsをリストで受け取れるようにした。複数のモデルを指定した場合、Pipelineモデルを作成した上でデプロイを行う。
from simple_sagemaker_manager.executor import SageMakerExecutor
from simple_sagemaker_manager.executor.classes import EndpointInstance, Model

client = SageMakerExecutor()


# 特定のモデルをデプロイする場合
# modelsに複数のモデルを指定すると、Pipelineモデルを作成して利用します
client.deploy_endpoint(
    instance=EndpointInstance(
        instance_type='ml.m4.xlarge',
        initial_count=1,
        initial_variant_wright=1
    ),
    models=[
        Model(
            name='decision-trees-sample-191028-111309-538454',
            model_arn='arn:aws:sagemaker:ap-northeast-1:xxxxxxxxxx',
            image_uri='xxxxxxxxxx.dkr.ecr.ap-northeast-1.amazonaws.com/decision-trees-sample:latest',
            model_data_url='s3://xxxxxxxxxx/model.tar.gz'
        )
    ],
    name='sample-endpoint',
    role="arn:aws:iam::xxxxxxxxxx"
)

# execute_batch_trainingの結果を渡すこともできます
model = client.execute_batch_training(
    # 引数は略
) 

client.deploy_endpoint(
    instance=EndpointInstance(
        instance_type='ml.m4.xlarge',
        initial_count=1,
        initial_variant_wright=1
    ),
    models=[model],
    name='sample-endpoint',
    role="arn:aws:iam::xxxxxxxxxx"
)

エンドポイント以外(学習バッチジョブなど)の名前は、自動で現在時刻の文字列を足して重複を避ける実装にしています。ただしエンドポイントだけは「同じ名前のものがあればアップデートする」という挙動にして利便性を高めています。

また、省略しますが、バッチ変換ジョブのメソッドも同様に実装しています。

これからの課題

このように実装して、今は一部のプロジェクトの実装で実際に利用してもらっています。ただ未実装の課題もいくつかあり、チーム内の他の課題もまだまだあります。

  • テーブルデータ以外(画像など)のコンテナイメージの雛形は未実装です
  • 学習バッチのSpotInstanceで、処理の中断時に中間状態をS3に保存する機能
  • ローカルで推論APIを立ち上げるなど、APIの実装やテストを支援する機能が無い
  • 学習時のメトリクスの取得。これは別メンバーがMLflowに注目して検証しています。
  • データの分析やモデルの精度評価についてもまだまだ課題が多そう

また、チーム内で実際に利用してもらうと、使い勝手の悪い部分もそれなりに出てきてしまっているので、それらの問題を解決してより機械学習プロジェクトを効率化していこうと思います。

9
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
9
3