0
3

More than 3 years have passed since last update.

自前のXGBoostコードをSageMakerコードに変換する

Last updated at Posted at 2020-05-10

PoCやKaggleで書くコードを、SageMakerで仕組み化する場合にどうコードを変換するのかわかりにくかったので、まとめる。

変換前

乳がんの二値分類をXGboostを用いて判別したコード。
こちらのサイトを参考にさせていただいた。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import xgboost as xgb
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

"""XGBoost で二値分類するサンプルコード"""
# 乳がんデータセットを読み込む
dataset = datasets.load_breast_cancer()
x, y = dataset.data, dataset.target
# データセットを学習用とテスト用に分割する
train_x, test_x, train_y, test_y = train_test_split(x, y,
                                                    test_size=0.2,
                                                    shuffle=True,
                                                    random_state=42,
                                                    stratify=y)
# さらに学習用データを学習用とvalid用に分割する
tr_x, va_x, tr_y, va_y = train_test_split(train_x, train_y,
                                                    test_size=0.2,
                                                    shuffle=True,
                                                    random_state=42,
                                                    stratify=train_y)
# XGBoost が扱うデータセットの形式に直す
dtrain = xgb.DMatrix(tr_x, label=tr_y)
dvalid = xgb.DMatrix(va_x, label=va_y)
dtest = xgb.DMatrix(test_x)
# 学習用のパラメータ
xgb_params = {
    # 二値分類問題
    'objective': 'binary:logistic',
    # 評価指標
    'eval_metric': 'logloss',
}
# モデルを学習する
# バリデーションデータもモデルに渡し、学習の進行とともにスコアがどう変わるかモニタリングする
# watchlistには学習データおよびバリデーションデータをセットする
watchlist = [(dtrain, 'train'), (dvalid, 'eval')]
model = xgb.train(xgb_params,
                dtrain,
                num_boost_round=50,  # 学習ラウンド数は適当
                evals=watchlist
                )
# 予測:検証用データが各クラスに分類される確率を計算する
pred_proba = model.predict(dtest)
# しきい値 0.5 で 0, 1 に丸める
pred = np.where(pred_proba > 0.5, 1, 0)
# 精度 (Accuracy) を検証する
acc = accuracy_score(test_y, pred)
print('Accuracy:', acc)

以上、よくある形式。52行

SageMakerでXGBoostを行う場合

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import xgboost as xgb
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

"""XGBoost で二値分類するサンプルコード"""
# 乳がんデータセットを読み込む
dataset = datasets.load_breast_cancer()
x, y = dataset.data, dataset.target
# データセットを学習用とテスト用に分割する
train_x, test_x, train_y, test_y = train_test_split(x, y,
                                                    test_size=0.2,
                                                    shuffle=True,
                                                    random_state=42,
                                                    stratify=y)
# さらに学習用データを学習用とvalid用に分割する
tr_x, va_x, tr_y, va_y = train_test_split(train_x, train_y,
                                                    test_size=0.2,
                                                    shuffle=True,
                                                    random_state=42,
                                                    stratify=train_y)
##### ここから異なる ####################################################
import pandas as pd
### 最初の列に目的変数を挿入
tr = np.insert(tr_x, 0, tr_y, axis=1)
va = np.insert(va_x, 0, va_y, axis=1)

# ローカルにcsv作成(ヘッダなし)
pd.DataFrame(tr).to_csv('train.csv', header=False, index=False)
pd.DataFrame(va).to_csv('validation.csv', header=False, index=False)

### S3にアップ
import sagemaker
from sagemaker import get_execution_role

# default_bucket=None を任意で指定することも可能
# https://sagemaker.readthedocs.io/en/stable/session.html
sagemaker_session = sagemaker.Session()

# バケット: sagemaker-<region>-<アカウントID> に保存される。(上記で指定しなかったため)
input_train = sagemaker_session.upload_data(path='train.csv', key_prefix='sagemaker/xgb-breast-cancer')
input_validation = sagemaker_session.upload_data(path='validation.csv', key_prefix='sagemaker/xgb-breast-cancer')

### content_typeを指定する
from sagemaker.session import s3_input

s3_input_train = s3_input(s3_data=input_train, content_type='text/csv')
s3_input_validation = s3_input(s3_data=input_validation, content_type='text/csv')

### SageMaker XGboostのコンテナパス取得
import boto3
from sagemaker.amazon.amazon_estimator import get_image_uri
container = get_image_uri(boto3.Session().region_name, 'xgboost')

### estimatorの設定
role = get_execution_role() # sagemaker ノートブックインスタンス、Studioからのみ有効

sess = sagemaker.Session()

xgb = sagemaker.estimator.Estimator(container,
                                    role, 
                                    train_instance_count=1, 
                                    train_instance_type='ml.m4.xlarge',
                                    sagemaker_session=sess)
### ハイパーパラメータの設定
### num_roundはここでしか宣言できない
xgb.set_hyperparameters(
    objective='binary:logistic',
    eval_metric='logloss',
    num_round=50
)

### トレーニングジョブの実行
xgb.fit({'train': s3_input_train, 'validation': s3_input_validation})

### エンドポイントのデプロイ(予測を実施するサーバを作成)
xgb_predictor = xgb.deploy(initial_instance_count = 1, instance_type = 'ml.m4.xlarge')

from sagemaker.predictor import csv_serializer

xgb_predictor.content_type = 'text/csv'
xgb_predictor.serializer = csv_serializer
xgb_predictor.deserializer = None

### 予測の実行
pred_proba_sagemaker = xgb_predictor.predict(test_x).decode('utf-8')

### 整形
pred_proba_sagemaker_arr = np.fromstring(pred_proba_sagemaker[0:], sep=',')

##### ここまで。以後同じ ####################################################
pred_sagemaker = np.where(pred_proba_sagemaker_arr > 0.5, 1, 0)

# 精度 (Accuracy) を検証する
acc_sagemaker = accuracy_score(test_y, pred_sagemaker)
print('Accuracy:', acc_sagemaker)

100行。

詳細

後で書く

まとめ

・単にモデルを作って検証するだけなら、SageMakerの機能は使わない方がスムーズ。強いて言えば、GPUで並列処理が必要な巨大が学習機能が必要な場合は、SpotインスタンスのPシリーズ並列学習を利用する。くらい。

・SageMaker Studioで実施すれば、都度インスタンスを切り替える必要なし。かつ、インスタンス上でSageMaker用処理なしで実験可能。

・実験は、SageMaker Experienmentsを利用して、自動的に管理される。(個人の管理意識によらない)

・PoC段階で、「そもそもビジネスにインパクトを出せるモデルを作れるか」という検証では、SageMaker Studioか、ローカルで十分(ローカルはライブラリ入れるのが面倒だが)。R、Sparkを使いたい場合はSageMaker ノートブックインスタンス or EC2。

・PoCでビジネスバリューが出せると判断し、仕組み化する場合に、SageMakerのマネージド機能が真価を発揮する。(日次バッチで学習、推論サーバのデプロイ)

参考

Python: XGBoost を使ってみる
https://blog.amedama.jp/entry/2019/01/29/235642

XGBoost による顧客離反分析 (Churn Analysis)
https://github.com/aws-samples/amazon-sagemaker-examples-jp/blob/master/xgboost_customer_churn/xgboost_customer_churn.ipynb

SageMaker Python SDKのローカルモードを利用して、ノートブックインスタンス以外の環境で学習ジョブを回してみる
https://dev.classmethod.jp/articles/sagemaker-python-sdk-localmode/

推論パイプラインでリアルタイム予測を実行
https://docs.aws.amazon.com/ja_jp/sagemaker/latest/dg/inference-pipeline-real-time.html

0
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
3