2
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

機械学習のスキルを身に着けたい人向けにチュートリアルの内容を整理しました。

今回は、機械学習モデルをトレーニングする チュートリアルを活用します。
https://aws.amazon.com/jp/getting-started/hands-on/machine-learning-tutorial-train-a-model/

機械学習モデルの構築からデプロイまでの一般的なワークフローを確認する事ができます。ハイパーパラメータの調整、モデルのバイアス検証、デプロイなどのステップを組み合わせて、機械学習プロセスを構築しています。

【身に着ける事が出来るスキル】
Amazon SageMaker Studio と Amazon SageMaker Clarify を使用して、
機械学習 (ML) モデルをトレーニング、チューニング、評価する事が出来るようになります。

・スクリプトモードを使用したモデルの構築、トレーニング、およびチューニング
・ML モデルのバイアスを検出し、モデルの予測
・トレーニングされたモデルをリアルタイム推論エンドポイントにデプロイし、テスト
・サンプル予測を生成し、特微量の影響を理解することにより、モデルを評価

このチュートリアルでは、合成的に生成された自動車保険の請求のデータセットを使用します。入力はトレーニング、検証、テストデータセットで、それぞれ請求と顧客に関する詳細や抽出された特徴量、そして請求が不正かそうでないかを示す不正列を含んでいます。オープンソースの XGBoost フレームワークを使用して、この合成データセットで二項分類モデルを構築し、請求が不正である可能性を予測します。 また、バイアスと特微量の重要性レポートを実行してトレーニング済みモデルを評価し、テスト用にモデルをデプロイし、サンプル推論を実行してモデルのパフォーマンスを評価し、予測を説明することができます。

請求が不正である可能性を予測するモデルを作成する機械学習のプロセス

イメージ
+---------------------+ +---------------------+ +---------------------+
| ①問題の定義 ②データの準備 ③データの集合と統合
|
| 🎯 📊 🔀
|
+-----------+---------+ +---------------------+ +---------------------+
①不正請求の予測モデルの目標
②請求データ、顧客情報、特徴量の抽出
③自動車保険の請求データセットの統合
|
v
+----------------------------------------------------------------+
| データの可視化と分析                
| 📈
| 散布図や統計的手法を使用してデータの傾向や関係性を観察
+----------------------------------------------------------------+
|
v
+----------------------------------------------------------------+
| 特徴量エンジニアリング
| 🛠️
| 特徴量を適切に選択し、必要に応じて変換や生成を行う
+----------------------------------------------------------------+
|
v
+----------------------------------------------------------------+
| モデルトレーニング
| 🤖
| 選択した機械学習アルゴリズムを使用してモデルをトレーニング
+----------------------------------------------------------------+
|
v
+----------------------------------------------------------------+
| モデル評価
| 📊
| テストデータでモデルの性能を評価
+----------------------------------------------------------------+
|
v
+----------------------------------------------------------------+
| 本番環境の準備:予測の実行
| 🚀
| 新しいデータに対してモデルを使用して予測
+----------------------------------------------------------------+

  1. 問題の定義 (①問題の定式化)

    • 目標: 請求が不正であるかどうかを二項分類するモデルを作成する。
    • 入力: 請求と顧客に関する詳細や抽出された特徴量、不正列。
  2. データの準備 (②データの準備)

    • トレーニングデータ、検証データ、テストデータにデータセットを分割。
    • 特徴量エンジニアリング: 適切な特徴量を選択し、変換や生成を行う。
    • ラベル付け: 不正列を目的変数として利用。
  3. データの集合と統合 (③データの集合と統合)

    • 合成的に生成された自動車保険の請求データセットを統合。
  4. データの可視化と分析

    • データの傾向や関係性を散布図や統計的手法を使用して観察。
  5. 特徴量エンジニアリング

    • 特徴量の選択や変換を行い、モデルに適した形にデータを整形。
  6. モデルトレーニング

    • 選択した機械学習アルゴリズム(XGBoostなど)を使用してモデルをトレーニング。
    • トレーニングデータに基づいてモデルが不正請求を予測するパターンを学習。
  7. モデル評価

    • テストデータでモデルの性能を評価し、精度や再現率、適合率などの評価指標を確認。
    • モデルが不正請求をどれくらい正確に予測できるかを検証。
  8. 本番環境の準備:予測の実行

    • モデルを本番環境にデプロイし、新しいデータに対して予測を行えるように準備。
    • 予測の実行時、新しい請求データに基づいてモデルが不正請求の可能性を予測。

このステップにより、請求が不正である可能性を予測するモデルが構築されます。このモデルは、実際のデータを基にトレーニングされ、新しいデータに対して予測を行うことができます。

ステップ 1: Amazon SageMaker Studio ドメインを設定する

【前提】
バージニア北部リージョンで、VPCとサブネットを作成しておく必要があります。

ドメインを作成する(今回は、シングルユーザ向けの設定ボタンを押します。)
image.png

ドメインが作成中となる
image.png

ドメインが作成される
image.png

ユーザプロファイルを作成する
image.png

ユーザプロファイル設定を行う(今回は、そのままの設定を活用します)
image.png

ユーザプロファイル設定を行う(今回は、そのままの設定を活用します)
image.png

ユーザプロファイル設定を行う(今回は、そのままの設定を活用します)
image.png

ユーザプロファイル設定を行う(今回は、そのままの設定を活用します)
image.png

ユーザプロファイルが作成される
image.png

Studioを開く
image.png

ステップ 2: SageMaker Studio ノートブックを設定する

Studio Classicを開く
image.png

Runで起動する
image.png

Openで開く
image.png

インターフェスを確認する
image.png

ノートブックを開く
image.png

Kernelをスタート
image.png

Kernelがスタート中となる
image.png

ライブラリをインポートする
image.png

動作確認済みコード(2024/1/27)
チュートリアルのままやろうとしても、エラーになるので次のように修正します。

%pip install -q  xgboost==1.3.1
import pandas pd

【実行後の画面】
image.png

S3 クライアントオブジェクトと、メトリクスやモデルアーティファクトなどのコンテンツがアップロードされるデフォルトの S3 バケット内の場所をインスタンス化する
image.png

動作確認済みコード(2024/1/27)

import pandas as pd
import boto3
import sagemaker
import json
import joblib
from sagemaker.xgboost.estimator import XGBoost
from sagemaker.tuner import (
    IntegerParameter,
    ContinuousParameter,
    HyperparameterTuner
)
from sagemaker.inputs import TrainingInput
from sagemaker.image_uris import retrieve
from sagemaker.serializers import CSVSerializer
from sagemaker.deserializers import CSVDeserializer

# Setting SageMaker variables
sess = sagemaker.Session()
write_bucket = sess.default_bucket()
write_prefix = "fraud-detect-demo"

region = sess.boto_region_name
s3_client = boto3.client("s3", region_name=region)

sagemaker_role = sagemaker.get_execution_role()
sagemaker_client = boto3.client("sagemaker")
read_bucket = "sagemaker-sample-files"
read_prefix = "datasets/tabular/synthetic_automobile_claims" 


# Setting S3 location for read and write operations
train_data_key = f"{read_prefix}/train.csv"
test_data_key = f"{read_prefix}/test.csv"
validation_data_key = f"{read_prefix}/validation.csv"
model_key = f"{write_prefix}/model"
output_key = f"{write_prefix}/output"


train_data_uri = f"s3://{read_bucket}/{train_data_key}"
test_data_uri = f"s3://{read_bucket}/{test_data_key}"
validation_data_uri = f"s3://{read_bucket}/{validation_data_key}"
model_uri = f"s3://{write_bucket}/{model_key}"
output_uri = f"s3://{write_bucket}/{output_key}"
estimator_output_uri = f"s3://{write_bucket}/{write_prefix}/training_jobs"
bias_report_output_uri = f"s3://{write_bucket}/{write_prefix}/clarify-output/bias"
explainability_report_output_uri = f"s3://{write_bucket}/{write_prefix}/clarify-output/explainability"

提供されたコードは、AWS SageMakerを使用してXGBoostモデルをトレーニングおよびデプロイするためのセットアップを行っています。入力データ、モデルアーティファクト、およびさまざまな出力レポートの場所をS3に指定しています。

ライブラリのインポート

  • import pandas as pd: データ操作のためのpandasライブラリをインポートし、pdとしてエイリアスを付けます。
  • import boto3: Python用のAWS SDKであるboto3ライブラリをインポートします。
  • import sagemaker: AWS SageMaker上で機械学習を行うためのSageMakerライブラリをインポートします。
  • import json: JSONデータを処理するためのJSONライブラリをインポートします。
  • import joblib: Pythonでの並列処理のためのjoblibライブラリをインポートします。
  • from sagemaker.xgboost.estimator import XGBoost: SageMakerからXGBoostモデルをトレーニングするためのXGBoostエスティメータをインポートします。
  • from sagemaker.tuner import (IntegerParameter, ContinuousParameter, HyperparameterTuner): SageMakerでのハイパーパラメータの調整のためのクラスをインポートします。
  • from sagemaker.inputs import TrainingInput: SageMakerトレーニングジョブの入力データを定義するためのTrainingInputクラスをインポートします。
  • from sagemaker.image_uris import retrieve: 特定のアルゴリズムコンテナのイメージURIを取得するためのretrieve関数をインポートします。
  • from sagemaker.serializers import CSVSerializer: 入力データをCSV形式でシリアライズするためのCSVSerializerをインポートします。
  • from sagemaker.deserializers import CSVDeserializer: 出力データをCSV形式でデシリアライズするためのCSVDeserializerをインポートします。

SageMaker変数の設定

  • sess = sagemaker.Session(): SageMakerセッションを作成します。
  • write_bucket = sess.default_bucket(): SageMakerセッションに関連付けられたデフォルトのS3バケットを取得します。
  • write_prefix = "fraud-detect-demo": S3キーを整理するためのプレフィックスを設定します。
  • region = sess.boto_region_name: SageMakerセッションからAWSリージョンを取得します。
  • s3_client = boto3.client("s3", region_name=region): S3と対話するためにboto3を使用してS3クライアントを作成します。
  • sagemaker_role = sagemaker.get_execution_role(): SageMaker実行ロールを取得します。
  • sagemaker_client = boto3.client("sagemaker"): boto3を使用してSageMakerクライアントを作成します。
  • read_bucket = "sagemaker-sample-files": 入力データセットが保存されているS3バケットを設定します。
  • read_prefix = "datasets/tabular/synthetic_automobile_claims": 入力データセットが保存されているS3キーのプレフィックスを設定します。

読み書きのためのS3ロケーションの設定

  • トレーニング、テスト、検証データセット、モデルアーティファクト、および出力のS3キーを定義します。
  • 対応するリソースのためのS3 URIを作成します。

追加のS3 URI

  • SageMakerトレーニングジョブの出力、バイアスレポート、および説明性レポートの出力用に追加のS3 URIを定義します。

【実行後の画面】
image.png

image.png

image.png

モデル名とトレーニングおよび推論インスタンスの構成と回数を設定します。これらの設定により、適切なインスタンスタイプとカウントを使用して、トレーニングと推論のプロセスを管理することができます。

image.png

動作確認済みコード(2024/1/27)

tuning_job_name_prefix = "xgbtune" 
training_job_name_prefix = "xgbtrain"

xgb_model_name = "fraud-detect-xgb-model"
endpoint_name_prefix = "xgb-fraud-model-dev"
train_instance_count = 1
train_instance_type = "ml.m4.xlarge"
predictor_instance_count = 1
predictor_instance_type = "ml.m4.xlarge"
clarify_instance_count = 1
clarify_instance_type = "ml.m4.xlarge"

デプロイするための設定を行っています。これらの設定は、モデルのトレーニング、ハイパーパラメータの調整、推論エンドポイントのデプロイ、データバイアスの検証など、SageMakerを使用して機械学習ワークフローを構築するための基本的な構成を指定しています。

tuning_job_name_prefix ハイパーパラメータの調整(hyperparameter tuning)を行うための SageMaker チューニングジョブの名前の接頭辞です。xgbtune という名前が使われます。

training_job_name_prefix モデルのトレーニングに使用される SageMaker トレーニングジョブの名前の接頭辞です。xgbtrain という名前が使われます。

xgb_model_name XGBoostモデルの名前です。fraud-detect-xgb-model という名前が使われます。

endpoint_name_prefix デプロイされたエンドポイントの名前の接頭辞です。xgb-fraud-model-dev という名前が使われます。

train_instance_count トレーニングに使用されるインスタンスの数です。ここでは1つのインスタンス (train_instance_count = 1) が指定されています。

train_instance_type トレーニングに使用されるインスタンスのタイプです。ml.m4.xlarge という、中程度のリソースを持つインスタンスが指定されています。

predictor_instance_count 推論(inference)のためのエンドポイントに使用されるインスタンスの数です。ここでは1つのインスタンス (predictor_instance_count = 1) が指定されています。

predictor_instance_type 推論に使用されるエンドポイントのインスタンスのタイプです。ml.m4.xlarge という、中程度のリソースを持つインスタンスが指定されています。

clarify_instance_count データの偏りやモデルの誤りを検証するための SageMaker Clarify ジョブに使用されるインスタンスの数です。ここでは1つのインスタンス (clarify_instance_count = 1) が指定されています。

clarify_instance_type SageMaker Clarify ジョブに使用されるインスタンスのタイプです。ml.m4.xlarge という、中程度のリソースを持つインスタンスが指定されています。

【実行後の画面】
image.png

image.png

ステップ 3: スクリプトモードでのハイパーパラメータチューニングジョブの起動

SageMaker Studio では、Python スクリプト内に独自のロジックを持ち込んでトレーニングに使用することができます。トレーニングロジックをスクリプトでカプセル化することで、AWS が管理する共通の機械学習フレームワークコンテナを使用しながら、カスタムトレーニングルーチンやモデル構成を取り入れることができます。このチュートリアルでは、AWS が提供する XGBoost コンテナがサポートするオープンソースの XGBoost フレームワークを使用したトレーニングスクリプトを用意し、ハイパーパラメータのチューニングジョブを大規模に起動します。モデルのトレーニングには、ターゲットカラムとして不正列を使用します。

image.png

動作確認済みコード(2024/1/27)

%%writefile xgboost_train.py

import argparse
import os
import joblib
import json
import pandas as pd
import xgboost as xgb
from sklearn.metrics import roc_auc_score

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Hyperparameters and algorithm parameters are described here
    parser.add_argument("--num_round", type=int, default=100)
    parser.add_argument("--max_depth", type=int, default=3)
    parser.add_argument("--eta", type=float, default=0.2)
    parser.add_argument("--subsample", type=float, default=0.9)
    parser.add_argument("--colsample_bytree", type=float, default=0.8)
    parser.add_argument("--objective", type=str, default="binary:logistic")
    parser.add_argument("--eval_metric", type=str, default="auc")
    parser.add_argument("--nfold", type=int, default=3)
    parser.add_argument("--early_stopping_rounds", type=int, default=3)
    

    # SageMaker specific arguments. Defaults are set in the environment variables
    # Location of input training data
    parser.add_argument("--train_data_dir", type=str, default=os.environ.get("SM_CHANNEL_TRAIN"))
    # Location of input validation data
    parser.add_argument("--validation_data_dir", type=str, default=os.environ.get("SM_CHANNEL_VALIDATION"))
    # Location where trained model will be stored. Default set by SageMaker, /opt/ml/model
    parser.add_argument("--model_dir", type=str, default=os.environ.get("SM_MODEL_DIR"))
    # Location where model artifacts will be stored. Default set by SageMaker, /opt/ml/output/data
    parser.add_argument("--output_data_dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR"))
    
    args = parser.parse_args()

    data_train = pd.read_csv(f"{args.train_data_dir}/train.csv")
    train = data_train.drop("fraud", axis=1)
    label_train = pd.DataFrame(data_train["fraud"])
    dtrain = xgb.DMatrix(train, label=label_train)
    
    
    data_validation = pd.read_csv(f"{args.validation_data_dir}/validation.csv")
    validation = data_validation.drop("fraud", axis=1)
    label_validation = pd.DataFrame(data_validation["fraud"])
    dvalidation = xgb.DMatrix(validation, label=label_validation)

    params = {"max_depth": args.max_depth,
              "eta": args.eta,
              "objective": args.objective,
              "subsample" : args.subsample,
              "colsample_bytree":args.colsample_bytree
             }
    
    num_boost_round = args.num_round
    nfold = args.nfold
    early_stopping_rounds = args.early_stopping_rounds
    
    cv_results = xgb.cv(
        params=params,
        dtrain=dtrain,
        num_boost_round=num_boost_round,
        nfold=nfold,
        early_stopping_rounds=early_stopping_rounds,
        metrics=["auc"],
        seed=42,
    )
    
    model = xgb.train(params=params, dtrain=dtrain, num_boost_round=len(cv_results))
    
    train_pred = model.predict(dtrain)
    validation_pred = model.predict(dvalidation)
    
    train_auc = roc_auc_score(label_train, train_pred)
    validation_auc = roc_auc_score(label_validation, validation_pred)
    
    print(f"[0]#011train-auc:{train_auc:.2f}")
    print(f"[0]#011validation-auc:{validation_auc:.2f}")

    metrics_data = {"hyperparameters" : params,
                    "binary_classification_metrics": {"validation:auc": {"value": validation_auc},
                                                      "train:auc": {"value": train_auc}
                                                     }
                   }
              
    # Save the evaluation metrics to the location specified by output_data_dir
    metrics_location = args.output_data_dir + "/metrics.json"
    
    # Save the model to the location specified by model_dir
    model_location = args.model_dir + "/xgboost-model"

    with open(metrics_location, "w") as f:
        json.dump(metrics_data, f)

    with open(model_location, "wb") as f:
        joblib.dump(model, f)

このスクリプトは、XGBoostを使用してバイナリ分類モデルをトレーニングし、評価メトリクスとモデル自体を保存するためのものです。

引数の取得
argparseを使用して、スクリプトに渡されるハイパーパラメータやSageMakerの特定の引数を取得します。これにはトレーニングデータや検証データのディレクトリ、モデルの保存場所などが含まれます。

データの読み込み
トレーニングデータと検証データを読み込み、XGBoostが処理できる形式に変換します。

XGBoostハイパーパラメータの設定
argparseで受け取ったハイパーパラメータを使用して、XGBoostモデルのトレーニングに使用するパラメータを設定します。

クロスバリデーションによるモデルトレーニング
xgb.cvを使用して、クロスバリデーションによるトレーニングを行います。この際、指定されたメトリクス(ここではAUC)を監視し、早期停止を実装します。

最終モデルのトレーニング
クロスバリデーションで得られた最適なラウンド数を用いて最終モデルのトレーニングを行います。

モデルの評価
トレーニングデータと検証データに対するAUCスコアを計算して出力します。

メトリクスとモデルの保存
トレーニングと評価のメトリクスをJSONファイルに保存します。
トレーニングされたモデルをjoblib形式で保存します。

【実行後の画面】
image.png

image.png

image.png

image.png

トレーニングスクリプトを準備したら、SageMaker 推定子をインスタンス化することができます。カスタムスクリプトを実行できる XGBoost コンテナを管理するため、AWS が管理する XGBoost 推定子を使用します。
XGBoost 推定子をインスタンス化するには、次の内容を実行します。

image.png

動作確認済みコード(2024/1/27)

# SageMaker estimator

# Set static hyperparameters that will not be tuned
# チューニングされない静的ハイパーパラメータを設定
static_hyperparams = {  
                        "eval_metric" : "auc",
                        "objective": "binary:logistic",
                        "num_round": "5"
                      }

# XGBoost Estimatorを構築
xgb_estimator = XGBoost(
                        entry_point="xgboost_train.py",  # トレーニングスクリプトのファイル名
                        output_path=estimator_output_uri,  # トレーニングアーティファクトの出力先S3 URI
                        code_location=estimator_output_uri,  # トレーニングスクリプトのコードの出力先S3 URI
                        hyperparameters=static_hyperparams,  # 静的ハイパーパラメータの設定
                        role=sagemaker_role,  # SageMakerトレーニングジョブの実行に使用するIAMロール
                        instance_count=train_instance_count,  # トレーニングインスタンスの数
                        instance_type=train_instance_type,  # トレーニングインスタンスのタイプ
                        framework_version="1.3-1",  # 使用するXGBoostのフレームワークバージョン
                        base_job_name=training_job_name_prefix  # トレーニングジョブのベース名
                    )

XGBoost Estimatorが構築され、その設定がxgb_estimatorとして保存されます。このEstimatorを使用してモデルをトレーニングするためのSageMakerトレーニングジョブを実行できます。

このコードは、SageMakerのEstimatorを設定しています。

静的ハイパーパラメーターの設定

  • static_hyperparams: チューニングされない静的なハイパーパラメーターを設定します。
  • ここでは、評価指標("eval_metric")、目的関数("objective")、ラウンド数("num_round")が設定されています。

XGBoost Estimatorの作成

  • xgb_estimator = XGBoost(...): XGBoost Estimatorを作成します。
  • entry_point: トレーニングスクリプトのエントリーポイントファイルへのパスを指定します。
  • output_path: トレーニングジョブの出力のS3パスを指定します。
  • code_location: トレーニングジョブのコードのS3ロケーションを指定します。
  • hyperparameters: ハイパーパラメーターを指定します。ここでは、静的ハイパーパラメーターが指定されています。
  • role: SageMakerが使用するIAMロールを指定します。
  • instance_count: トレーニングインスタンスの数を指定します。
  • instance_type: トレーニングインスタンスのタイプを指定します。
  • framework_version: 使用するXGBoostのフレームワークバージョンを指定します。
  • base_job_name: トレーニングジョブのベース名を指定します。

Estimatorを作成する主な理由
Amazon SageMakerで機械学習モデルのトレーニングやデプロイを簡単かつ効果的に行うためです。

トレーニングジョブの管理 Estimatorは、トレーニングジョブの起動、監視、停止などを管理します。ユーザーは、モデルのトレーニングに関連する詳細な手順や設定を直接扱わずに、Estimatorを使用してこれらのタスクを容易に実行できます。

モデルのデプロイ Estimatorはトレーニングされたモデルをデプロイ可能なエンドポイントにデプロイするための機能も提供します。これにより、トレーニングしたモデルを実際のデータに適用して推論を行うためのエンドポイントが簡単に作成できます。

Dockerコンテナの構築 SageMakerはDockerコンテナを使用してカスタムトレーニングスクリプトを実行します。Estimatorは、これらのコンテナの作成と管理を行い、ユーザーはDockerに関する詳細な設定を気にする必要がありません。

ハイパーパラメータの設定 Estimatorを使用すると、ハイパーパラメータの設定が簡単に行えます。ユーザーは、Estimatorのハイパーパラメータとして必要な値を指定するだけで、自動的にハイパーパラメータの探索ジョブを実行することもできます。

分散トレーニングのサポート:SageMaker Estimatorは、分散トレーニングに対応しており、複数のインスタンスを使用して効率的にトレーニングを行うことができます。

Estimatorを使用することで、SageMakerの機能をより効果的に活用し、ユーザーが手動で多くの詳細なタスクを処理する必要がなくなります。これにより、機械学習プロジェクトの開発や管理が簡素化され、生産性が向上します。

【実行後の画面】
image.png

image.png

推定子を指定する際に、静的な設定パラメータを指定することができます。このチュートリアルでは、評価メトリクスとして Receiver Operating Characteristics Area Under the Curve (ROC-AUC) を使用します。実行にかかる時間を制御するため、ラウンド数を 5 回に設定しています。

カスタムスクリプトとトレーニングインスタンスの構成は、推定子オブジェクトに引数として渡されます。XGBoost のバージョンは、先にインストールしたものと同じものが選ばれています。
image.png

4 つの XGBoost のハイパーパラメータをチューニングします。

eta: 過剰適合を防ぐために更新で使用されるステップサイズの縮小。各ブースティングステップの後、新しい特徴量の重みを直接取得することができます。eta パラメータは実際に特徴量の重みを縮小し、ブースティング処理をより保守的にします。

サブサンプル: トレーニングインスタンスのサブサンプル比率。0.5 に設定すると、XGBoost はツリーを成長させる前にトレーニングデータの半分をランダムにサンプリングすることを意味します。ブースティングの反復ごとに異なるサブセットを使用することで、過剰適合を防ぐことができます。

colsample_bytree: ブースティング処理の各ツリーを生成するために使用される特徴量の一部。各ツリーの生成に特徴量のサブセットを用いることで、モデリングプロセスのランダム性を高め、一般化率を向上させます。

max_depth: ツリーの最大深度。この値を大きくすると、モデルがより複雑になり、過剰適合する可能性が高くなります。

image.png

動作確認済みコード(2024/1/27)

# Setting ranges of hyperparameters to be tuned
hyperparameter_ranges = {
    "eta": ContinuousParameter(0, 1),
    "subsample": ContinuousParameter(0.7, 0.95),
    "colsample_bytree": ContinuousParameter(0.7, 0.95),
    "max_depth": IntegerParameter(1, 5)
}

機械学習モデルのハイパーパラメータを調整するための範囲を設定しています。

"eta" 学習率を示しており、0から1の間の連続的な値をとります。
"subsample" サンプリングの割合を示しており、0.7から0.95の間の連続的な値をとります。
"colsample_bytree" ツリーを構築する際に特徴量のサンプリング割合を示しており、0.7から0.95の間の連続的な値をとります。
"max_depth" 決定木の最大の深さを示しており、1から5の間の整数値をとります。

【実行後の画面】
image.png

image.png

ハイパーパラメータチューナーを設定します。

SageMaker は、検索処理のためにベイズ最適化ルーチンをデフォルトで実行します。このチュートリアルでは、実行時間を短縮するためにランダム検索アプローチを使用します。パラメータは、検証データセットにおけるモデルの AUC パフォーマンスに基づいてチューニングされます。

image.png

動作確認済みコード(2024/1/27)

objective_metric_name = "validation:auc"

# Setting up tuner object
tuner_config_dict = {
                     "estimator" : xgb_estimator,
                     "max_jobs" : 5,
                     "max_parallel_jobs" : 2,
                     "objective_metric_name" : objective_metric_name,
                     "hyperparameter_ranges" : hyperparameter_ranges,
                     "base_tuning_job_name" : tuning_job_name_prefix,
                     "strategy" : "Random"
                    }
tuner = HyperparameterTuner(**tuner_config_dict)

SageMaker HyperparameterTunerオブジェクトを設定しています。

目的関数の設定

  • objective_metric_name = "validation:auc": チューニングの対象となる目的関数(評価指標)を設定します。ここでは、検証データのAUCが対象です。

Tunerオブジェクトの設定

  • tuner_config_dict: ハイパーパラメーターチューニングのための設定情報が辞書としてまとめられています。

    • estimator: チューニングするEstimatorオブジェクト(ここではXGBoost Estimator)を指定します。
    • max_jobs: チューニングジョブの最大数を指定します。
    • max_parallel_jobs: 同時に実行するチューニングジョブの最大数を指定します。
    • objective_metric_name: 目的関数(評価指標)の名前を指定します。
    • hyperparameter_ranges: チューニングするハイパーパラメーターの範囲を指定します。
    • base_tuning_job_name: チューニングジョブのベース名を指定します。
    • strategy: チューニングの戦略を指定します。ここでは"Random"戦略が指定されています。
  • tuner = HyperparameterTuner(**tuner_config_dict): HyperparameterTunerオブジェクトを構築します。

【実行後の画面】
image.png

チューナーオブジェクトの fit() メソッドを呼び出すと、ハイパーパラメータのチューニングジョブを起動することができます。チューナーのフィットのために、異なる入力チャネルを指定することができます。このチュートリアルでは、トレーニングと検証のチャネルを提供します。以下のコードブロックをコピーして貼り付け、ハイパーパラメータチューニングジョブを起動します。

image.png

動作確認済みコード(2024/1/27)

# Setting the input channels for tuning job
s3_input_train = TrainingInput(s3_data="s3://{}/{}".format(read_bucket, train_data_key), content_type="csv", s3_data_type="S3Prefix")
s3_input_validation = (TrainingInput(s3_data="s3://{}/{}".format(read_bucket, validation_data_key), 
                                    content_type="csv", s3_data_type="S3Prefix")
                      )

tuner.fit(inputs={"train": s3_input_train, "validation": s3_input_validation}, include_cls_metadata=False)
tuner.wait()

ハイパーパラメーターチューニングジョブのための入力チャネルを設定し、Tunerを使用してトレーニングジョブを開始しています。指定されたハイパーパラメーターの範囲で複数のモデルがトレーニングされ、最良のモデルが選択されます。

入力チャネルの設定

  • s3_input_trains3_input_validationは、トレーニングデータと検証データのS3入力チャネルを設定しています。
  • TrainingInputクラスを使用して、データのS3パス、コンテンツタイプ("csv")、データのS3タイプ("S3Prefix")を指定します。
  • トレーニングデータと検証データの入力チャネルがそれぞれs3_input_trainおよびs3_input_validationに設定されています。

Tunerのフィットと待機

  • tuner.fit(inputs={"train": s3_input_train, "validation": s3_input_validation}, include_cls_metadata=False): Tunerを使用してハイパーパラメーターチューニングジョブを開始します。トレーニングデータと検証データの入力チャネルを指定します。
  • include_cls_metadata=False: クラスメタデータを含まないように指定します。
  • tuner.wait(): チューニングジョブが完了するまで待機します。

【実行後の画面】
image.png

image.png

チューニングが完了すると、
結果のサマリーにアクセスすることができます。以下のコードブロックをコピーして貼り付けると、チューニングジョブの結果がパフォーマンスの降順で並んだ pandas データフレームに取得されます。

image.png

動作確認済みコード(2024/1/27)

# Summary of tuning results ordered in descending order of performance
df_tuner = sagemaker.HyperparameterTuningJobAnalytics(tuner.latest_tuning_job.job_name).dataframe()
df_tuner = df_tuner[df_tuner["FinalObjectiveValue"]>-float('inf')].sort_values("FinalObjectiveValue", ascending=False)
df_tuner

チューニング結果の要約を取得し、パフォーマンスの降順で表示するための処理です。
チューニングジョブの結果を取得し、最もパフォーマンスの良いジョブが上位になるようにデータを整形しています。

チューニング結果の取得

  • sagemaker.HyperparameterTuningJobAnalytics(tuner.latest_tuning_job.job_name): チューニングジョブの分析情報を取得するためのメソッドを使用します。tuner.latest_tuning_job.job_nameは、最新のチューニングジョブの名前を指定します。
  • .dataframe(): チューニングジョブの結果をDataFrame形式で取得します。

データのフィルタリングとソート

  • df_tuner[df_tuner["FinalObjectiveValue"] > -float('inf')]: チューニングジョブの最終的な目的の値が無限大でないデータをフィルタリングします。
  • .sort_values("FinalObjectiveValue", ascending=False): 最終的な目的の値を降順でソートします。これにより、最良のパフォーマンスを持つジョブが上位に表示されます。

最も良い性能を発揮したハイパーパラメータの組み合わせを確認することができます。

【実行後の画面】
image.png

出力内容


	colsample_bytree	eta	max_depth	subsample	TrainingJobName	TrainingJobStatus	FinalObjectiveValue	TrainingStartTime	TrainingEndTime	TrainingElapsedTimeSeconds
4	0.705083	0.710442	4.0	0.925637	xgbtune-240127-0152-001-7fd53feb	Completed	0.75	2024-01-27 01:53:40+00:00	2024-01-27 01:55:37+00:00	117.0
3	0.722878	0.784719	4.0	0.938876	xgbtune-240127-0152-002-30e68c27	Completed	0.74	2024-01-27 01:54:15+00:00	2024-01-27 01:56:03+00:00	108.0
1	0.848005	0.784886	5.0	0.885089	xgbtune-240127-0152-004-317dfed2	Completed	0.74	2024-01-27 01:56:12+00:00	2024-01-27 01:56:55+00:00	43.0
0	0.806768	0.598834	2.0	0.767711	xgbtune-240127-0152-005-aaca6bba	Completed	0.73	2024-01-27 01:56:58+00:00	2024-01-27 01:57:40+00:00	42.0
2	0.884716	0.413005	2.0	0.838182	xgbtune-240127-0152-003-c67d11a3	Completed	0.63	2024-01-27 01:56:08+00:00	2024-01-27 01:56:50+00:00	42.0

これは、ハイパーパラメーターチューニングの結果のサマリーで、各トレーニングジョブの主な情報が表示されています。列の意味は以下の通りです。

  • colsample_bytree: ツリーを構築する際の特徴量のサンプリング割合。
  • eta: 学習率。新しいモデルが以前のモデルの誤差をどれだけ修正するかを制御します。
  • max_depth: ツリーの深さの最大値。
  • subsample: トレーニングデータのサンプリング割合。
  • TrainingJobName: トレーニングジョブの名前。
  • TrainingJobStatus: トレーニングジョブのステータス(例: Completed)。
  • FinalObjectiveValue: チューニングジョブの最終的な目的の値(評価メトリクス)。
  • TrainingStartTime: トレーニングが開始された時刻。
  • TrainingEndTime: トレーニングが終了した時刻。
  • TrainingElapsedTimeSeconds: トレーニングが終了するまでの経過時間(秒)。

表を見ると、各トレーニングジョブのハイパーパラメーターセットとそのパフォーマンスが表示されています。最終的な目的の値が高いほど、モデルの性能が良いことを示しています。例えば、行4のモデル(xgbtune-240127-0152-001-7fd53feb)は最終的な目的の値が0.75で、他のモデルよりも高い性能を持っています。

ステップ 4: モデルのバイアスをチェックし、SageMaker Clarify を使ってモデルの予測を説明する

トレーニング済みモデルができたら、デプロイする前に、モデルやデータに固有のバイアスがないかどうかを理解することが重要です。モデルの予測はバイアスの原因となり得ます (例えば、あるグループに対して別のグループよりも頻繁に否定的な結果をもたらすような予測をする場合など)。SageMaker Clarify では、トレーニング済みモデルがどのように予測を行うのか、特徴量帰属アプローチを用いて説明します。このチュートリアルでは、モデルの説明可能性のために、トレーニング後のバイアスメトリクスと SHAP の値に焦点を当てます。具体的には、以下のような一般的なタスクがカバーされています。

データおよびモデルのバイアス検出
特徴量重要度値を用いたモデルの説明可能性
単一データサンプルの特徴量と局所的説明の影響

SageMaker Clarify がモデルバイアス検出を行うには、SageMaker モデルが必要です。SageMaker Clarify はこれを分析の一部としてエフェメラルエンドポイントにデプロイします。このエンドポイントは、SageMaker Clarify による分析が完了すると削除されます。以下のコードブロックをコピーして貼り付け、チューニングジョブから特定された最適なトレーニングジョブから SageMaker モデルを作成します。

image.png

動作確認済みコード(2024/1/27)

tuner_job_info = sagemaker_client.describe_hyper_parameter_tuning_job(HyperParameterTuningJobName=tuner.latest_tuning_job.job_name)

model_matches = sagemaker_client.list_models(NameContains=xgb_model_name)["Models"]

if not model_matches:
    _ = sess.create_model_from_job(
            name=xgb_model_name,
            training_job_name=tuner_job_info['BestTrainingJob']["TrainingJobName"],
            role=sagemaker_role,
            image_uri=tuner_job_info['TrainingJobDefinition']["AlgorithmSpecification"]["TrainingImage"]
            )
else:

    print(f"Model {xgb_model_name} already exists.")

ハイパーパラメーターチューニングジョブの情報を取得し、トレーニングされたモデルを作成します。ハイパーパラメーターチューニングが終了した後、最も優れたモデルを取得し、新しいモデルとしてSageMakerに登録するかどうかを判断します。

ハイパーパラメーターチューニングジョブの情報取得

  • tuner_job_info = sagemaker_client.describe_hyper_parameter_tuning_job(HyperParameterTuningJobName=tuner.latest_tuning_job.job_name): チューニングジョブの詳細情報を取得します。

モデルの存在確認

  • model_matches = sagemaker_client.list_models(NameContains=xgb_model_name)["Models"]: 指定されたモデル名を持つモデルが既に存在するかどうかを確認します。

モデルの作成またはスキップ

  • if not model_matches:: モデルが存在しない場合、新しいモデルを作成します。
    • sess.create_model_from_job(...): トレーニングジョブからモデルを作成します。
      • name: 作成するモデルの名前。
      • training_job_name: トレーニングジョブの名前。
      • role: SageMakerが使用するIAMロール。
      • image_uri: トレーニングに使用されたアルゴリズムのイメージURI。
  • else:: モデルが既に存在する場合、メッセージを表示します。

【実行後の画面】
image.png

バイアス検出を実行するために、SageMaker Clarify では複数のコンポーネント構成が設定されていることが予想されます。詳しくは Amazon SageMaker Clarify をご覧ください。このチュートリアルでは、標準的な構成とは別に、ターゲットが顧客の性別に基づいた値に偏っているかどうかをチェックすることによって、データが統計的に女性に偏っているかどうかを検出するように SageMaker Clarify を設定します。以下のコードをコピーして貼り付け、SageMaker Clarify 構成を設定します。

バイアス検出を実行するために、SageMaker Clarify では複数のコンポーネント構成が設定されていることが予想されます。詳しくは Amazon SageMaker Clarify をご覧ください。このチュートリアルでは、標準的な構成とは別に、ターゲットが顧客の性別に基づいた値に偏っているかどうかをチェックすることによって、データが統計的に女性に偏っているかどうかを検出するように SageMaker Clarify を設定します。以下のコードをコピーして貼り付け、SageMaker Clarify 構成を設定します。

image.png

動作確認済みコード(2024/1/27)

train_df = pd.read_csv(train_data_uri)
train_df_cols = train_df.columns.to_list()

clarify_processor = sagemaker.clarify.SageMakerClarifyProcessor(
    role=sagemaker_role,
    instance_count=clarify_instance_count,
    instance_type=clarify_instance_type,
    sagemaker_session=sess,
)

# Data config
bias_data_config = sagemaker.clarify.DataConfig(
    s3_data_input_path=train_data_uri,
    s3_output_path=bias_report_output_uri,
    label="fraud",
    headers=train_df_cols,
    dataset_type="text/csv",
)

# Model config
model_config = sagemaker.clarify.ModelConfig(
    model_name=xgb_model_name,
    instance_type=train_instance_type,
    instance_count=1,
    accept_type="text/csv",
)

# Model predictions config to get binary labels from probabilities
predictions_config = sagemaker.clarify.ModelPredictedLabelConfig(probability_threshold=0.5)

# Bias config
bias_config = sagemaker.clarify.BiasConfig(
    label_values_or_threshold=[0],
    facet_name="customer_gender_female",
    facet_values_or_threshold=[1],
)

# Run Clarify job
clarify_processor.run_bias(
    data_config=bias_data_config,
    bias_config=bias_config,
    model_config=model_config,
    model_predicted_label_config=predictions_config,
    pre_training_methods=["CI"],
    post_training_methods=["DPPL"])

clarify_bias_job_name = clarify_processor.latest_job.name

Amazon SageMaker Clarifyを使用してモデルのバイアスを評価するためのプロセスを設定しています。
SageMaker Clarifyを使用してデータバイアスの分析を行うための手順を示しています。

データの読み込み

  • train_df = pd.read_csv(train_data_uri): トレーニングデータをpandasのデータフレームとして読み込みます。
  • train_df_cols = train_df.columns.to_list(): トレーニングデータの列名を取得します。

Clarify Processorの設定

  • clarify_processor = sagemaker.clarify.SageMakerClarifyProcessor(...): SageMaker Clarify Processorを設定します。これは、データバイアス分析を実行するためのオブジェクトです。

データの設定

  • bias_data_config: Clarifyのデータ構成を設定します。データの場所、ラベル列、ヘッダー、データセットの種類を指定します。

モデルの設定

  • model_config: Clarifyのモデル構成を設定します。モデルの名前、インスタンスのタイプと数、受け入れるデータのタイプを指定します。

モデル予測の設定

  • predictions_config: Clarifyのモデル予測構成を設定します。ここでは、確率から二値のラベルを取得するように設定されています。

データバイアスの設定

  • bias_config: Clarifyのデータバイアス構成を設定します。ラベルの値や閾値、ファセット名やファセットの値を指定します。

Clarifyジョブの実行

  • clarify_processor.run_bias(...): Clarifyジョブを実行します。データ構成、データバイアス構成、モデル構成、モデル予測構成、事前トレーニングおよび事後トレーニングの方法を指定します。

Clarifyジョブ名の取得

  • clarify_bias_job_name = clarify_processor.latest_job.name: 最後に実行したClarifyジョブの名前を取得します。これは後で結果を確認するために使用できます。

【実行後の画面】
image.png

image.png

出力内容


INFO:sagemaker:Creating processing-job with name Clarify-Bias-2024-01-27-02-17-43-968
..........................................2024-01-27 02:24:45,245 logging.conf not found when configuring logging, using default logging configuration.
2024-01-27 02:24:45,246 Starting SageMaker Clarify Processing job
2024-01-27 02:24:45,246 Analysis config path: /opt/ml/processing/input/config/analysis_config.json
2024-01-27 02:24:45,246 Analysis result path: /opt/ml/processing/output
2024-01-27 02:24:45,246 This host is algo-1.
2024-01-27 02:24:45,247 This host is the leader.
2024-01-27 02:24:45,247 Number of hosts in the cluster is 1.
2024-01-27 02:24:45,529 Running Python / Pandas based analyzer.
2024-01-27 02:24:45,529 Dataset type: text/csv uri: /opt/ml/processing/input/data
2024-01-27 02:24:45,542 Loading dataset...
/usr/local/lib/python3.9/site-packages/analyzer/data_loading/csv_data_loader.py:336: FutureWarning: The frame.append method is deprecated and will be removed from pandas in a future version. Use pandas.concat instead.
  df = df.append(df_tmp, ignore_index=True)
2024-01-27 02:24:45,569 Loaded dataset. Dataset info:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 4009 entries, 0 to 4008
Data columns (total 48 columns):
 #   Column                           Non-Null Count  Dtype  
---  ------                           --------------  -----  
 0   num_vehicles_involved            4009 non-null   int64  
 1   num_injuries                     4009 non-null   int64  
 2   num_witnesses                    4009 non-null   int64  
 3   police_report_available          4009 non-null   int64  
 4   injury_claim                     4009 non-null   int64  
 5   vehicle_claim                    4009 non-null   float64
 6   total_claim_amount               4009 non-null   float64
 7   incident_month                   4009 non-null   int64  
 8   incident_day                     4009 non-null   int64  
 9   incident_dow                     4009 non-null   int64  
 10  incident_hour                    4009 non-null   int64  
 11  customer_age                     4009 non-null   int64  
 12  months_as_customer               4009 non-null   int64  
 13  num_claims_past_year             4009 non-null   int64  
 14  num_insurers_past_5_years        4009 non-null   int64  
 15  policy_deductable                4009 non-null   int64  
 16  policy_annual_premium            4009 non-null   int64  
 17  policy_liability                 4009 non-null   int64  
 18  customer_education               4009 non-null   int64  
 19  auto_year                        4009 non-null   int64  
 20  driver_relationship_other        4009 non-null   int64  
 21  driver_relationship_child        4009 non-null   int64  
 22  driver_relationship_spouse       4009 non-null   int64  
 23  driver_relationship_na           4009 non-null   int64  
 24  driver_relationship_self         4009 non-null   int64  
 25  incident_type_collision          4009 non-null   int64  
 26  incident_type_break-in           4009 non-null   int64  
 27  incident_type_theft              4009 non-null   int64  
 28  collision_type_rear              4009 non-null   int64  
 29  collision_type_side              4009 non-null   int64  
 30  collision_type_na                4009 non-null   int64  
 31  collision_type_front             4009 non-null   int64  
 32  incident_severity_totaled        4009 non-null   int64  
 33  incident_severity_major          4009 non-null   int64  
 34  incident_severity_minor          4009 non-null   int64  
 35  authorities_contacted_fire       4009 non-null   int64  
 36  authorities_contacted_none       4009 non-null   int64  
 37  authorities_contacted_police     4009 non-null   int64  
 38  authorities_contacted_ambulance  4009 non-null   int64  
 39  policy_state_ca                  4009 non-null   int64  
 40  policy_state_az                  4009 non-null   int64  
 41  policy_state_nv                  4009 non-null   int64  
 42  policy_state_id                  4009 non-null   int64  
 43  policy_state_wa                  4009 non-null   int64  
 44  policy_state_or                  4009 non-null   int64  
 45  customer_gender_other            4009 non-null   int64  
 46  customer_gender_male             4009 non-null   int64  
 47  customer_gender_female           4009 non-null   int64  
dtypes: float64(2), int64(46)
memory usage: 1.5 MB
2024-01-27 02:24:45,670 Spinning up shadow endpoint
2024-01-27 02:24:45,670 Creating endpoint-config with name sm-clarify-config-1706322285-5526
2024-01-27 02:24:46,150 Creating endpoint: 'sm-clarify-fraud-detect-xgb-model-1706322286-1442'
2024-01-27 02:24:46,220 No endpoints ruleset found for service sagemaker-internal, falling back to legacy endpoint routing.
2024-01-27 02:24:46,806 Using endpoint name: sm-clarify-fraud-detect-xgb-model-1706322286-1442
2024-01-27 02:24:46,807 Waiting for endpoint ...
2024-01-27 02:24:46,807 Checking endpoint status:
Legend:
(OutOfService: x, Creating: -, Updating: -, InService: !, RollingBack: <, Deleting: o, Failed: *)
2024-01-27 02:28:47,568 Endpoint is in service after 241 seconds
2024-01-27 02:28:47,569 Endpoint ready.
2024-01-27 02:28:47,569 ======================================
2024-01-27 02:28:47,569 Calculating post-training bias metrics
2024-01-27 02:28:47,569 ======================================
2024-01-27 02:28:47,569 Getting predictions from the endpoint
2024-01-27 02:28:48,419 We assume a prediction above 0.500 indicates 1 and below or equal indicates 0.
2024-01-27 02:28:48,420 Column fraud with data uniqueness fraction 0.0004988775255674732 is classifed as a CATEGORICAL column
2024-01-27 02:28:48,423 Column customer_gender_female with data uniqueness fraction 0.0004988775255674732 is classifed as a CATEGORICAL column
/usr/local/lib/python3.9/site-packages/smclarify/bias/report.py:591: FutureWarning: In a future version of pandas all arguments of DataFrame.drop except for the argument 'labels' will be keyword-only.
  df = df.drop(facet_column.name, 1)
2024-01-27 02:28:48,425 Column fraud with data uniqueness fraction 0.0004988775255674732 is classifed as a CATEGORICAL column
2024-01-27 02:28:48,427 Column None with data uniqueness fraction 0.0004988775255674732 is classifed as a CATEGORICAL column
2024-01-27 02:28:48,429 Calculated global analysis with predictor
2024-01-27 02:28:48,429 Stop using endpoint: sm-clarify-fraud-detect-xgb-model-1706322286-1442
2024-01-27 02:28:48,429 Deleting endpoint configuration with name: sm-clarify-config-1706322285-5526
2024-01-27 02:28:48,586 Deleting endpoint with name: sm-clarify-fraud-detect-xgb-model-1706322286-1442
2024-01-27 02:28:48,666 Model endpoint delivered 2.32368 requests per second and a total of 2 requests over 1 seconds
2024-01-27 02:28:48,666 =====================================
2024-01-27 02:28:48,666 Calculating pre-training bias metrics
2024-01-27 02:28:48,666 =====================================
2024-01-27 02:28:48,667 Column fraud with data uniqueness fraction 0.0004988775255674732 is classifed as a CATEGORICAL column
2024-01-27 02:28:48,669 Column customer_gender_female with data uniqueness fraction 0.0004988775255674732 is classifed as a CATEGORICAL column
/usr/local/lib/python3.9/site-packages/smclarify/bias/report.py:591: FutureWarning: In a future version of pandas all arguments of DataFrame.drop except for the argument 'labels' will be keyword-only.
  df = df.drop(facet_column.name, 1)
2024-01-27 02:28:48,671 Column fraud with data uniqueness fraction 0.0004988775255674732 is classifed as a CATEGORICAL column
2024-01-27 02:28:48,673 ======================================
2024-01-27 02:28:48,673 Calculating bias statistics for report
2024-01-27 02:28:48,674 ======================================
2024-01-27 02:28:48,674 Column fraud with data uniqueness fraction 0.0004988775255674732 is classifed as a CATEGORICAL column
2024-01-27 02:28:48,676 Column customer_gender_female with data uniqueness fraction 0.0004988775255674732 is classifed as a CATEGORICAL column
/usr/local/lib/python3.9/site-packages/smclarify/bias/report.py:591: FutureWarning: In a future version of pandas all arguments of DataFrame.drop except for the argument 'labels' will be keyword-only.
  df = df.drop(facet_column.name, 1)
2024-01-27 02:28:48,677 Column fraud with data uniqueness fraction 0.0004988775255674732 is classifed as a CATEGORICAL column
2024-01-27 02:28:48,679 Column None with data uniqueness fraction 0.0004988775255674732 is classifed as a CATEGORICAL column
2024-01-27 02:28:48,684 Column fraud with data uniqueness fraction 0.0004988775255674732 is classifed as a CATEGORICAL column
2024-01-27 02:28:48,686 Column None with data uniqueness fraction 0.0004988775255674732 is classifed as a CATEGORICAL column
2024-01-27 02:28:48,694 Converting Pandas DataFrame to SparkDataFrame for computing report metadata
---!02:28:50.694 [main] WARN  o.a.hadoop.util.NativeCodeLoader - Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
02:28:56.726 [Thread-4] WARN  o.a.spark.sql.catalyst.util.package - Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
#015[Stage 0:>                                                          (0 + 4) / 4]#015#015                                                                                #015#015[Stage 3:>                                                          (0 + 4) / 4]#015#015                                                                                #0152024-01-27 02:29:03,689 Calculated global analysis without predictor
2024-01-27 02:29:03,690 Stop using endpoint: None
2024-01-27 02:29:04,598 ['jupyter', 'nbconvert', '--to', 'html', '--output', '/opt/ml/processing/output/report.html', '/opt/ml/processing/output/report.ipynb', '--template', 'sagemaker-xai']
[NbConvertApp] Converting notebook /opt/ml/processing/output/report.ipynb to html
[NbConvertApp] Writing 744861 bytes to /opt/ml/processing/output/report.html
2024-01-27 02:29:05,863 ['wkhtmltopdf', '-q', '--enable-local-file-access', '/opt/ml/processing/output/report.html', '/opt/ml/processing/output/report.pdf']
2024-01-27 02:29:06,990 Collected analyses: 
{
    "version": "1.0",
    "post_training_bias_metrics": {
        "label": "fraud",
        "facets": {
            "customer_gender_female": [
                {
                    "value_or_threshold": "1",
                    "metrics": [
                        {
                            "name": "DPPL",
                            "description": "Difference in Positive Proportions in Predicted Labels (DPPL)",
                            "value": 0.0009633794348478109
                        }
                    ]
                }
            ]
        },
        "label_value_or_threshold": "0"
    },
    "pre_training_bias_metrics": {
        "label": "fraud",
        "facets": {
            "customer_gender_female": [
                {
                    "value_or_threshold": "1",
                    "metrics": [
                        {
                            "name": "CI",
                            "description": "Class Imbalance (CI)",
                            "value": 0.12846096283362435
                        }
                    ]
                }
            ]
        },
        "label_value_or_threshold": "0"
    }
}
2024-01-27 02:29:06,991 exit_message: Completed: SageMaker XAI Analyzer ran successfully


このログは、SageMaker Clarifyを使用してデータバイアス分析を実行した後の詳細な情報を提供しています。ログの内容を解説します:

  • processing-jobの作成 最初の行で、Clarifyバイアス分析のためのprocessing-jobが作成されることが示されています。
    INFO:sagemaker:Creating processing-job with name Clarify-Bias-2024-01-27-02-17-43-968: SageMaker Clarifyの処理ジョブが作成されました。

  • データの読み込み データセットが読み込まれ、その情報が表示されます。読み込まれたデータは、PandasのDataFrameとして表示され、各列の情報(データ型やユニークな値の割合など)が表示されます。

  • Shadow endpointの作成 Clarifyが実行されるためのshadow endpoint(推論エンドポイントのクローン)が作成されます。
    Spinning up shadow endpoint: Clarifyのためにシャドウエンドポイントが作成されています。

  • モデルの推論 作成されたエンドポイントを使用してモデルの推論が行われます。推論結果を使用してバイアス分析が実行されます。

  • エンドポイントの待機
    Waiting for endpoint ...: エンドポイントが起動するのを待っています。

  • エンドポイントの稼働
    Endpoint is in service after 241 seconds: エンドポイントが正常に稼働しました。

  • バイアス統計の計算 バイアス統計の計算が行われ、結果が提供されます。このログでは、DPPL(予測されたラベルの陽性比率の差異)やCI(クラスの不均衡)などのバイアスメトリクスが示されています。
    Calculating post-training bias metrics: トレーニング後のデータバイアスメトリクスが計算されています。

  • Categorical Columnの特定
    Column fraud with data uniqueness fraction 0.0004988775255674732 is classifed as a CATEGORICAL column: データバイアス分析の際、カテゴリカルな列が特定されています。

  • Pre-training Bias Metricsの計算
    Calculating pre-training bias metrics: トレーニング前のデータバイアスメトリクスが計算されています。

  • Bias Statisticsの計算
    Calculating bias statistics for report: バイアス統計情報が計算されています。

  • Clarifyジョブの結果
    Collected analyses: Clarifyジョブの結果が示されています。post_training_bias_metricsおよびpre_training_bias_metricsの結果が含まれており、それぞれのメトリクスやファセットの詳細が示されています。

  • HTMLレポートの生成
    Writing 744861 bytes to /opt/ml/processing/output/report.html: 分析結果を含むHTMLレポートが生成されています。

  • PDFレポートの生成
    wkhtmltopdf: HTMLレポートからPDFレポートが生成されています。

  • 処理ジョブの完了
    exit_message: Completed: SageMaker XAI Analyzer ran successfully: ジョブが正常に完了しました。

これらのログを通じて、データバイアスの分析が正常に行われ、それに対するさまざまなメトリクスが計算されていることがわかります。

SageMaker Clarify 内では、トレーニング前のメトリクスはデータ内の既存のバイアスを示し、トレーニング後のメトリクスはモデルからの予測のバイアスを示します。SageMaker SDK を使用すると、バイアスをチェックするグループと、考慮するバイアスメトリクスを指定することができます。このチュートリアルでは、Class Imbalance (CI) と予測ラベルの正比例の差 (DPPL) を、それぞれトレーニング前とトレーニング後のバイアス統計の例として使用します。その他のバイアスメトリクスの詳細は、トレーニング前のバイアスの測定とトレーニング後のデータとモデルのバイアスをご覧ください。以下のコードブロックをコピーして貼り付けると、SageMaker Clarify が実行され、バイアスレポートが生成されます。選択したバイアスメトリクスは run_bias メソッドの引数として渡されます。

image.png

動作確認済みコード(2024/1/27)

clarify_processor.run_bias(
    data_config=bias_data_config,                 # データ構成の設定
    bias_config=bias_config,                       # バイアスの構成
    model_config=model_config,                     # モデルの構成
    model_predicted_label_config=predictions_config,  # モデルの予測ラベルの構成
    pre_training_methods=["CI"],                   # 事前トレーニングのバイアスメソッド(ここではClass Imbalance)
    post_training_methods=["DPPL"]                 # 事後トレーニングのバイアスメソッド(ここではDifference in Positive Proportions in Predicted Labels)
)

clarify_bias_job_name = clarify_processor.latest_job.name  # 最新のClarifyジョブの名前を取得

SageMaker Clarify Processorを使用して、バイアスレポートを生成するためのジョブを開始しています。
指定されたデータとモデルに対してバイアス評価が実行され、結果が収集されるという流れが確立されます。

clarify_processor.run_bias()の呼び出し clarify_processorオブジェクトのrun_bias()メソッドを使用して、バイアスの評価処理が開始されます。このメソッドには次のパラメータが渡されています。

  • data_config: データの構成を指定します。この中には、入力データのS3パス、ラベル列の名前、データセットのヘッダー情報などが含まれます。
  • bias_config: バイアスの構成を指定します。ここでは、どの属性に関してバイアスを評価するかを指定しています。
  • model_config: モデルの構成を指定します。この中には、モデルのエンドポイント名やインスタンスタイプなどが含まれます。
  • model_predicted_label_config: モデルが予測するラベルの構成を指定します。ここでは、確率を閾値で2値化する方法を指定しています。
  • pre_training_methods: トレーニング前のバイアス評価手法を指定します。ここでは、CI(Class Imbalance)が使用されています。
  • post_training_methods: トレーニング後のバイアス評価手法を指定します。ここでは、DPPL(Difference in Positive Proportions in Predicted Labels)が使用されています。

ジョブ名の取得 clarify_processor.latest_job.nameを使用して、最後に実行されたClarifyジョブの名前を取得し、clarify_bias_job_name変数に格納されます。

【実行後の画面】
image.png

出力内容


INFO:sagemaker.clarify:Analysis Config: {'dataset_type': 'text/csv', 'headers': ['fraud', 'num_vehicles_involved', 'num_injuries', 'num_witnesses', 'police_report_available', 'injury_claim', 'vehicle_claim', 'total_claim_amount', 'incident_month', 'incident_day', 'incident_dow', 'incident_hour', 'customer_age', 'months_as_customer', 'num_claims_past_year', 'num_insurers_past_5_years', 'policy_deductable', 'policy_annual_premium', 'policy_liability', 'customer_education', 'auto_year', 'driver_relationship_other', 'driver_relationship_child', 'driver_relationship_spouse', 'driver_relationship_na', 'driver_relationship_self', 'incident_type_collision', 'incident_type_break-in', 'incident_type_theft', 'collision_type_rear', 'collision_type_side', 'collision_type_na', 'collision_type_front', 'incident_severity_totaled', 'incident_severity_major', 'incident_severity_minor', 'authorities_contacted_fire', 'authorities_contacted_none', 'authorities_contacted_police', 'authorities_contacted_ambulance', 'policy_state_ca', 'policy_state_az', 'policy_state_nv', 'policy_state_id', 'policy_state_wa', 'policy_state_or', 'customer_gender_other', 'customer_gender_male', 'customer_gender_female'], 'label': 'fraud', 'label_values_or_threshold': [0], 'facet': [{'name_or_index': 'customer_gender_female', 'value_or_threshold': [1]}], 'methods': {'report': {'name': 'report', 'title': 'Analysis Report'}, 'pre_training_bias': {'methods': ['CI']}, 'post_training_bias': {'methods': ['DPPL']}}, 'predictor': {'model_name': 'fraud-detect-xgb-model', 'instance_type': 'ml.m4.xlarge', 'initial_instance_count': 1, 'accept_type': 'text/csv'}, 'probability_threshold': 0.5}
INFO:sagemaker:Creating processing-job with name Clarify-Bias-2024-01-27-02-32-25-869
...........................................2024-01-27 02:39:34,693 logging.conf not found when configuring logging, using default logging configuration.
2024-01-27 02:39:34,694 Starting SageMaker Clarify Processing job
2024-01-27 02:39:34,694 Analysis config path: /opt/ml/processing/input/config/analysis_config.json
2024-01-27 02:39:34,694 Analysis result path: /opt/ml/processing/output
2024-01-27 02:39:34,694 This host is algo-1.
2024-01-27 02:39:34,694 This host is the leader.
2024-01-27 02:39:34,694 Number of hosts in the cluster is 1.
2024-01-27 02:39:34,985 Running Python / Pandas based analyzer.
2024-01-27 02:39:34,985 Dataset type: text/csv uri: /opt/ml/processing/input/data
2024-01-27 02:39:34,998 Loading dataset...
/usr/local/lib/python3.9/site-packages/analyzer/data_loading/csv_data_loader.py:336: FutureWarning: The frame.append method is deprecated and will be removed from pandas in a future version. Use pandas.concat instead.
  df = df.append(df_tmp, ignore_index=True)
2024-01-27 02:39:35,025 Loaded dataset. Dataset info:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 4009 entries, 0 to 4008
Data columns (total 48 columns):
 #   Column                           Non-Null Count  Dtype  
---  ------                           --------------  -----  
 0   num_vehicles_involved            4009 non-null   int64  
 1   num_injuries                     4009 non-null   int64  
 2   num_witnesses                    4009 non-null   int64  
 3   police_report_available          4009 non-null   int64  
 4   injury_claim                     4009 non-null   int64  
 5   vehicle_claim                    4009 non-null   float64
 6   total_claim_amount               4009 non-null   float64
 7   incident_month                   4009 non-null   int64  
 8   incident_day                     4009 non-null   int64  
 9   incident_dow                     4009 non-null   int64  
 10  incident_hour                    4009 non-null   int64  
 11  customer_age                     4009 non-null   int64  
 12  months_as_customer               4009 non-null   int64  
 13  num_claims_past_year             4009 non-null   int64  
 14  num_insurers_past_5_years        4009 non-null   int64  
 15  policy_deductable                4009 non-null   int64  
 16  policy_annual_premium            4009 non-null   int64  
 17  policy_liability                 4009 non-null   int64  
 18  customer_education               4009 non-null   int64  
 19  auto_year                        4009 non-null   int64  
 20  driver_relationship_other        4009 non-null   int64  
 21  driver_relationship_child        4009 non-null   int64  
 22  driver_relationship_spouse       4009 non-null   int64  
 23  driver_relationship_na           4009 non-null   int64  
 24  driver_relationship_self         4009 non-null   int64  
 25  incident_type_collision          4009 non-null   int64  
 26  incident_type_break-in           4009 non-null   int64  
 27  incident_type_theft              4009 non-null   int64  
 28  collision_type_rear              4009 non-null   int64  
 29  collision_type_side              4009 non-null   int64  
 30  collision_type_na                4009 non-null   int64  
 31  collision_type_front             4009 non-null   int64  
 32  incident_severity_totaled        4009 non-null   int64  
 33  incident_severity_major          4009 non-null   int64  
 34  incident_severity_minor          4009 non-null   int64  
 35  authorities_contacted_fire       4009 non-null   int64  
 36  authorities_contacted_none       4009 non-null   int64  
 37  authorities_contacted_police     4009 non-null   int64  
 38  authorities_contacted_ambulance  4009 non-null   int64  
 39  policy_state_ca                  4009 non-null   int64  
 40  policy_state_az                  4009 non-null   int64  
 41  policy_state_nv                  4009 non-null   int64  
 42  policy_state_id                  4009 non-null   int64  
 43  policy_state_wa                  4009 non-null   int64  
 44  policy_state_or                  4009 non-null   int64  
 45  customer_gender_other            4009 non-null   int64  
 46  customer_gender_male             4009 non-null   int64  
 47  customer_gender_female           4009 non-null   int64  
dtypes: float64(2), int64(46)
memory usage: 1.5 MB
2024-01-27 02:39:35,126 Spinning up shadow endpoint
2024-01-27 02:39:35,127 Creating endpoint-config with name sm-clarify-config-1706323175-2c97
2024-01-27 02:39:35,613 Creating endpoint: 'sm-clarify-fraud-detect-xgb-model-1706323175-d968'
2024-01-27 02:39:35,684 No endpoints ruleset found for service sagemaker-internal, falling back to legacy endpoint routing.
2024-01-27 02:39:36,286 Using endpoint name: sm-clarify-fraud-detect-xgb-model-1706323175-d968
2024-01-27 02:39:36,287 Waiting for endpoint ...
2024-01-27 02:39:36,287 Checking endpoint status:
Legend:
(OutOfService: x, Creating: -, Updating: -, InService: !, RollingBack: <, Deleting: o, Failed: *)
2024-01-27 02:43:37,006 Endpoint is in service after 241 seconds
2024-01-27 02:43:37,006 Endpoint ready.
2024-01-27 02:43:37,007 ======================================
2024-01-27 02:43:37,007 Calculating post-training bias metrics
2024-01-27 02:43:37,007 ======================================
2024-01-27 02:43:37,007 Getting predictions from the endpoint
2024-01-27 02:43:37,856 We assume a prediction above 0.500 indicates 1 and below or equal indicates 0.
2024-01-27 02:43:37,856 Column fraud with data uniqueness fraction 0.0004988775255674732 is classifed as a CATEGORICAL column
2024-01-27 02:43:37,859 Column customer_gender_female with data uniqueness fraction 0.0004988775255674732 is classifed as a CATEGORICAL column
/usr/local/lib/python3.9/site-packages/smclarify/bias/report.py:591: FutureWarning: In a future version of pandas all arguments of DataFrame.drop except for the argument 'labels' will be keyword-only.
  df = df.drop(facet_column.name, 1)
2024-01-27 02:43:37,861 Column fraud with data uniqueness fraction 0.0004988775255674732 is classifed as a CATEGORICAL column
2024-01-27 02:43:37,863 Column None with data uniqueness fraction 0.0004988775255674732 is classifed as a CATEGORICAL column
2024-01-27 02:43:37,866 Calculated global analysis with predictor
2024-01-27 02:43:37,866 Stop using endpoint: sm-clarify-fraud-detect-xgb-model-1706323175-d968
2024-01-27 02:43:37,866 Deleting endpoint configuration with name: sm-clarify-config-1706323175-2c97
2024-01-27 02:43:38,026 Deleting endpoint with name: sm-clarify-fraud-detect-xgb-model-1706323175-d968
2024-01-27 02:43:38,115 Model endpoint delivered 2.32758 requests per second and a total of 2 requests over 1 seconds
2024-01-27 02:43:38,115 =====================================
2024-01-27 02:43:38,115 Calculating pre-training bias metrics
2024-01-27 02:43:38,115 =====================================
2024-01-27 02:43:38,116 Column fraud with data uniqueness fraction 0.0004988775255674732 is classifed as a CATEGORICAL column
2024-01-27 02:43:38,118 Column customer_gender_female with data uniqueness fraction 0.0004988775255674732 is classifed as a CATEGORICAL column
/usr/local/lib/python3.9/site-packages/smclarify/bias/report.py:591: FutureWarning: In a future version of pandas all arguments of DataFrame.drop except for the argument 'labels' will be keyword-only.
  df = df.drop(facet_column.name, 1)
2024-01-27 02:43:38,120 Column fraud with data uniqueness fraction 0.0004988775255674732 is classifed as a CATEGORICAL column
2024-01-27 02:43:38,122 ======================================
2024-01-27 02:43:38,122 Calculating bias statistics for report
2024-01-27 02:43:38,122 ======================================
2024-01-27 02:43:38,123 Column fraud with data uniqueness fraction 0.0004988775255674732 is classifed as a CATEGORICAL column
2024-01-27 02:43:38,124 Column customer_gender_female with data uniqueness fraction 0.0004988775255674732 is classifed as a CATEGORICAL column
/usr/local/lib/python3.9/site-packages/smclarify/bias/report.py:591: FutureWarning: In a future version of pandas all arguments of DataFrame.drop except for the argument 'labels' will be keyword-only.
  df = df.drop(facet_column.name, 1)
2024-01-27 02:43:38,126 Column fraud with data uniqueness fraction 0.0004988775255674732 is classifed as a CATEGORICAL column
2024-01-27 02:43:38,128 Column None with data uniqueness fraction 0.0004988775255674732 is classifed as a CATEGORICAL column
2024-01-27 02:43:38,133 Column fraud with data uniqueness fraction 0.0004988775255674732 is classifed as a CATEGORICAL column
2024-01-27 02:43:38,134 Column None with data uniqueness fraction 0.0004988775255674732 is classifed as a CATEGORICAL column
2024-01-27 02:43:38,143 Converting Pandas DataFrame to SparkDataFrame for computing report metadata
---!02:43:40.141 [main] WARN  o.a.hadoop.util.NativeCodeLoader - Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
02:43:46.009 [Thread-4] WARN  o.a.spark.sql.catalyst.util.package - Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
#015[Stage 0:>                                                          (0 + 4) / 4]#015#015                                                                                #015#015[Stage 3:>                                                          (0 + 4) / 4]#015#015                                                                                #0152024-01-27 02:43:52,777 Calculated global analysis without predictor
2024-01-27 02:43:52,778 Stop using endpoint: None
2024-01-27 02:43:54,011 ['jupyter', 'nbconvert', '--to', 'html', '--output', '/opt/ml/processing/output/report.html', '/opt/ml/processing/output/report.ipynb', '--template', 'sagemaker-xai']
[NbConvertApp] Converting notebook /opt/ml/processing/output/report.ipynb to html
[NbConvertApp] Writing 744861 bytes to /opt/ml/processing/output/report.html
2024-01-27 02:43:55,313 ['wkhtmltopdf', '-q', '--enable-local-file-access', '/opt/ml/processing/output/report.html', '/opt/ml/processing/output/report.pdf']
2024-01-27 02:43:56,419 Collected analyses: 
{
    "version": "1.0",
    "post_training_bias_metrics": {
        "label": "fraud",
        "facets": {
            "customer_gender_female": [
                {
                    "value_or_threshold": "1",
                    "metrics": [
                        {
                            "name": "DPPL",
                            "description": "Difference in Positive Proportions in Predicted Labels (DPPL)",
                            "value": 0.0009633794348478109
                        }
                    ]
                }
            ]
        },
        "label_value_or_threshold": "0"
    },
    "pre_training_bias_metrics": {
        "label": "fraud",
        "facets": {
            "customer_gender_female": [
                {
                    "value_or_threshold": "1",
                    "metrics": [
                        {
                            "name": "CI",
                            "description": "Class Imbalance (CI)",
                            "value": 0.12846096283362435
                        }
                    ]
                }
            ]
        },
        "label_value_or_threshold": "0"
    }
}
2024-01-27 02:43:56,420 exit_message: Completed: SageMaker XAI Analyzer ran successfully



このログは、SageMaker Clarifyを使用してバイアスを評価するために実行されたジョブの詳細を示しています。内容を解説します:

Analysis Config

  • dataset_type: 使用されるデータセットのタイプ(ここではtext/csv)。
  • headers: データセットのヘッダー情報。
  • label: バイアスを評価するためのラベルカラム(ここでは 'fraud')。
  • label_values_or_threshold: ラベルの値または閾値(ここでは [0])。
  • facet: バイアスを評価するための追加の要因(ここでは customer_gender_female)。
  • methods: バイアスの計測方法(ここでは 'CI' および 'DPPL')。
  • predictor: 使用される予測モデルの構成情報。
  • probability_threshold: 予測ラベルの確率の閾値(ここでは 0.5)。

Processing Job Start

  • Clarifyのプロセッシングジョブが作成され、ジョブ名が表示されます(ここでは Clarify-Bias-2024-01-27-02-32-25-869)。

Loading Dataset

  • データセットが読み込まれ、その概要が表示されます。データセットはCSV形式で、48の列と4009の行があります。

Spinning up Shadow Endpoint

  • シャドウエンドポイントが生成され、エンドポイント構成が作成されます。

Creating Endpoint

  • バイアス評価のためのモデルエンドポイントが作成されます。

Calculating Post-training Bias Metrics

  • 事後トレーニングのバイアスメトリクスが計算され、その結果が表示されます(ここでは 'DPPL' メトリクス)。

Calculating Pre-training Bias Metrics

  • 事前トレーニングのバイアスメトリクスが計算され、その結果が表示されます(ここでは 'CI' メトリクス)。

Calculating Bias Statistics for Report

  • レポート用のバイアス統計が計算されます。

Analyses Report Generation

  • レポートが生成され、HTMLおよびPDF形式に変換されます。ファイルの出力先は /opt/ml/processing/output/report.html および /opt/ml/processing/output/report.pdf です。

Bias Metrics Results

  • バイアスメトリクスの計算結果が表示されます。それには事前トレーニングおよび事後トレーニングのバイアスメトリクスが含まれます。

PDF レポートでは、トレーニング前とトレーニング後のバイアスメトリクスに基づき、データセットには顧客の性別の特微量に関してクラスの不均衡があるように見えます。このようなアンバランスは、SMOTE などの技術を適用してトレーニングデータセットを再作成することで修正することができます。また、SageMaker Data Wrangler を使用して、サービス内で利用可能な SMOTE を含む複数のオプションのいずれかを指定し、トレーニングデータセットのバランスをとることができます。
image.png

動作確認済みコード(2024/1/27)

# Copy bias report and view locally
!aws s3 cp s3://{write_bucket}/{write_prefix}/clarify-output/bias/report.pdf ./clarify_bias_output.pdf

【実行後の画面】
image.png

image.png

出力されたレポート
image.png

image.png

image.png

image.png

image.png

image.png

SageMaker Clarify では、データのバイアスに加えて、トレーニング済みモデルを分析し、特微量重要度に基づいてモデルの説明可能性レポートを作成することもできます。SageMaker Clarify は SHAP 値を使用して、各入力特微量が最終的な予測に与える寄与度を説明します。以下のコードブロックをコピーして貼り付け、モデルの説明可能性分析を構成および実行します。

image.png

動作確認済みコード(2024/1/27)

explainability_data_config = sagemaker.clarify.DataConfig(
    s3_data_input_path=train_data_uri,
    s3_output_path=explainability_report_output_uri,
    label="fraud",
    headers=train_df_cols,
    dataset_type="text/csv",
)

# Use mean of train dataset as baseline data point
shap_baseline = [list(train_df.drop(["fraud"], axis=1).mean())]

shap_config = sagemaker.clarify.SHAPConfig(
    baseline=shap_baseline,
    num_samples=500,
    agg_method="mean_abs",
    save_local_shap_values=True,
)

clarify_processor.run_explainability(
    data_config=explainability_data_config,
    model_config=model_config,
    explainability_config=shap_config
)

このコードは、SageMaker Clarifyを使用してモデルの説明可能性を評価しています。評価結果は、指定されたS3パスに保存されます。

explainability_data_configの設定

  • s3_data_input_path: 学習データのS3パス。
  • s3_output_path: 説明可能性のレポートの出力先S3パス。
  • label: 予測の対象となるラベルの列名(ここでは "fraud")。
  • headers: 学習データの列のリスト。
  • dataset_type: データセットのタイプ(ここでは "text/csv")。

ベースラインの設定

  • shap_baseline: SHAP値を計算するためのベースラインデータ。ここでは、学習データの各特徴量の平均値が使用されています。

SHAPの設定

  • baseline: SHAP値の計算の基準となるデータ(ベースライン)。
  • num_samples: SHAP値を計算するためのサンプル数。
  • agg_method: SHAP値を集計する方法(ここでは "mean_abs"、平均の絶対値)。
  • save_local_shap_values: SHAP値をローカルに保存するかどうか。

clarify_processor.run_explainabilityの呼び出し

  • data_config: 説明可能性のデータ設定。
  • model_config: モデルの設定。
  • explainability_config: SHAPの設定。

【実行後の画面】

image.png

image.png

image.png

出力内容


INFO:sagemaker.clarify:Analysis Config: {'dataset_type': 'text/csv', 'headers': ['fraud', 'num_vehicles_involved', 'num_injuries', 'num_witnesses', 'police_report_available', 'injury_claim', 'vehicle_claim', 'total_claim_amount', 'incident_month', 'incident_day', 'incident_dow', 'incident_hour', 'customer_age', 'months_as_customer', 'num_claims_past_year', 'num_insurers_past_5_years', 'policy_deductable', 'policy_annual_premium', 'policy_liability', 'customer_education', 'auto_year', 'driver_relationship_other', 'driver_relationship_child', 'driver_relationship_spouse', 'driver_relationship_na', 'driver_relationship_self', 'incident_type_collision', 'incident_type_break-in', 'incident_type_theft', 'collision_type_rear', 'collision_type_side', 'collision_type_na', 'collision_type_front', 'incident_severity_totaled', 'incident_severity_major', 'incident_severity_minor', 'authorities_contacted_fire', 'authorities_contacted_none', 'authorities_contacted_police', 'authorities_contacted_ambulance', 'policy_state_ca', 'policy_state_az', 'policy_state_nv', 'policy_state_id', 'policy_state_wa', 'policy_state_or', 'customer_gender_other', 'customer_gender_male', 'customer_gender_female'], 'label': 'fraud', 'predictor': {'model_name': 'fraud-detect-xgb-model', 'instance_type': 'ml.m4.xlarge', 'initial_instance_count': 1, 'accept_type': 'text/csv'}, 'methods': {'report': {'name': 'report', 'title': 'Analysis Report'}, 'shap': {'use_logit': False, 'save_local_shap_values': True, 'baseline': [[2.1085058618109254, 0.5584933898727862, 0.8685457720129708, 0.4220503866300823, 24257.121476677476, 17169.351123437555, 41426.472600115034, 6.726365677226241, 15.585682215016213, 2.645048640558743, 11.722624095784484, 44.15714642055375, 98.60688450985283, 0.08730356697430781, 1.4130705911698678, 751.0725866799701, 2925.3305063606886, 1.118233973559491, 2.531304564729359, 2015.7251184834124, 0.04065851833374907, 0.04489897730107259, 0.08505861810925418, 0.14342728860064854, 0.6859565976552756, 0.8565727113993514, 0.09553504614617112, 0.04789224245447742, 0.21900723372412073, 0.21052631578947367, 0.14342728860064854, 0.427039161885757, 0.23547019206784733, 0.34846595160888, 0.41606385632327264, 0.024195559990022448, 0.2432027937141432, 0.7031678722873534, 0.029433774008480917, 0.6303317535545023, 0.107009229234223, 0.04190571214766775, 0.027937141431778497, 0.12297331005238214, 0.06984285357944625, 0.0177101521576453, 0.4622100274382639, 0.43576951858318785]], 'num_samples': 500, 'agg_method': 'mean_abs'}}}
INFO:sagemaker:Creating processing-job with name Clarify-Explainability-2024-01-27-02-52-29-233
..........................................2024-01-27 02:59:26,916 logging.conf not found when configuring logging, using default logging configuration.
2024-01-27 02:59:26,917 Starting SageMaker Clarify Processing job
2024-01-27 02:59:26,917 Analysis config path: /opt/ml/processing/input/config/analysis_config.json
2024-01-27 02:59:26,917 Analysis result path: /opt/ml/processing/output
2024-01-27 02:59:26,918 This host is algo-1.
2024-01-27 02:59:26,918 This host is the leader.
2024-01-27 02:59:26,918 Number of hosts in the cluster is 1.
2024-01-27 02:59:27,201 Running Python / Pandas based analyzer.
2024-01-27 02:59:27,202 Dataset type: text/csv uri: /opt/ml/processing/input/data
2024-01-27 02:59:27,214 Loading dataset...
/usr/local/lib/python3.9/site-packages/analyzer/data_loading/csv_data_loader.py:336: FutureWarning: The frame.append method is deprecated and will be removed from pandas in a future version. Use pandas.concat instead.
  df = df.append(df_tmp, ignore_index=True)
2024-01-27 02:59:27,241 Loaded dataset. Dataset info:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 4009 entries, 0 to 4008
Data columns (total 48 columns):
 #   Column                           Non-Null Count  Dtype  
---  ------                           --------------  -----  
 0   num_vehicles_involved            4009 non-null   int64  
 1   num_injuries                     4009 non-null   int64  
 2   num_witnesses                    4009 non-null   int64  
 3   police_report_available          4009 non-null   int64  
 4   injury_claim                     4009 non-null   int64  
 5   vehicle_claim                    4009 non-null   float64
 6   total_claim_amount               4009 non-null   float64
 7   incident_month                   4009 non-null   int64  
 8   incident_day                     4009 non-null   int64  
 9   incident_dow                     4009 non-null   int64  
 10  incident_hour                    4009 non-null   int64  
 11  customer_age                     4009 non-null   int64  
 12  months_as_customer               4009 non-null   int64  
 13  num_claims_past_year             4009 non-null   int64  
 14  num_insurers_past_5_years        4009 non-null   int64  
 15  policy_deductable                4009 non-null   int64  
 16  policy_annual_premium            4009 non-null   int64  
 17  policy_liability                 4009 non-null   int64  
 18  customer_education               4009 non-null   int64  
 19  auto_year                        4009 non-null   int64  
 20  driver_relationship_other        4009 non-null   int64  
 21  driver_relationship_child        4009 non-null   int64  
 22  driver_relationship_spouse       4009 non-null   int64  
 23  driver_relationship_na           4009 non-null   int64  
 24  driver_relationship_self         4009 non-null   int64  
 25  incident_type_collision          4009 non-null   int64  
 26  incident_type_break-in           4009 non-null   int64  
 27  incident_type_theft              4009 non-null   int64  
 28  collision_type_rear              4009 non-null   int64  
 29  collision_type_side              4009 non-null   int64  
 30  collision_type_na                4009 non-null   int64  
 31  collision_type_front             4009 non-null   int64  
 32  incident_severity_totaled        4009 non-null   int64  
 33  incident_severity_major          4009 non-null   int64  
 34  incident_severity_minor          4009 non-null   int64  
 35  authorities_contacted_fire       4009 non-null   int64  
 36  authorities_contacted_none       4009 non-null   int64  
 37  authorities_contacted_police     4009 non-null   int64  
 38  authorities_contacted_ambulance  4009 non-null   int64  
 39  policy_state_ca                  4009 non-null   int64  
 40  policy_state_az                  4009 non-null   int64  
 41  policy_state_nv                  4009 non-null   int64  
 42  policy_state_id                  4009 non-null   int64  
 43  policy_state_wa                  4009 non-null   int64  
 44  policy_state_or                  4009 non-null   int64  
 45  customer_gender_other            4009 non-null   int64  
 46  customer_gender_male             4009 non-null   int64  
 47  customer_gender_female           4009 non-null   int64  
dtypes: float64(2), int64(46)
memory usage: 1.5 MB
2024-01-27 02:59:27,342 Spinning up shadow endpoint
2024-01-27 02:59:27,342 Creating endpoint-config with name sm-clarify-config-1706324367-f930
2024-01-27 02:59:27,782 Creating endpoint: 'sm-clarify-fraud-detect-xgb-model-1706324367-50db'
2024-01-27 02:59:27,852 No endpoints ruleset found for service sagemaker-internal, falling back to legacy endpoint routing.
2024-01-27 02:59:28,464 Using endpoint name: sm-clarify-fraud-detect-xgb-model-1706324367-50db
2024-01-27 02:59:28,464 Waiting for endpoint ...
2024-01-27 02:59:28,464 Checking endpoint status:
Legend:
(OutOfService: x, Creating: -, Updating: -, InService: !, RollingBack: <, Deleting: o, Failed: *)
2024-01-27 03:04:29,344 Endpoint is in service after 301 seconds
2024-01-27 03:04:29,344 Endpoint ready.
2024-01-27 03:04:29,466 Clarify Kernel SHAP n_coalitions: 500, n_instances: 1, n_features_to_explain: 48, model_output_size: 1
2024-01-27 03:04:29,466 =====================================================
2024-01-27 03:04:29,466 Shap analyzer: explaining 4009 rows, 48 columns...
2024-01-27 03:04:29,466 =====================================================
  0% (0 of 4009) |                       | Elapsed Time: 0:00:00 ETA:  --:--:--
  4% (164 of 4009) |                     | Elapsed Time: 0:00:30 ETA:   0:11:44
  8% (330 of 4009) |#                    | Elapsed Time: 0:01:00 ETA:   0:11:05
 12% (498 of 4009) |##                   | Elapsed Time: 0:01:30 ETA:   0:10:28
 16% (666 of 4009) |###                  | Elapsed Time: 0:02:00 ETA:   0:10:00
 20% (834 of 4009) |####                 | Elapsed Time: 0:02:30 ETA:   0:09:27
 25% (1003 of 4009) |#####               | Elapsed Time: 0:03:00 ETA:   0:08:55
 29% (1172 of 4009) |#####               | Elapsed Time: 0:03:30 ETA:   0:08:26
 33% (1339 of 4009) |######              | Elapsed Time: 0:04:00 ETA:   0:07:59
 37% (1507 of 4009) |#######             | Elapsed Time: 0:04:30 ETA:   0:07:29
 41% (1675 of 4009) |########            | Elapsed Time: 0:05:00 ETA:   0:06:58
 45% (1843 of 4009) |#########           | Elapsed Time: 0:05:31 ETA:   0:06:27
 50% (2011 of 4009) |##########          | Elapsed Time: 0:06:01 ETA:   0:05:58
 54% (2179 of 4009) |##########          | Elapsed Time: 0:06:31 ETA:   0:05:27
 58% (2346 of 4009) |###########         | Elapsed Time: 0:07:01 ETA:   0:04:59
 62% (2514 of 4009) |############        | Elapsed Time: 0:07:31 ETA:   0:04:27
 66% (2682 of 4009) |#############       | Elapsed Time: 0:08:01 ETA:   0:03:57
 71% (2851 of 4009) |##############      | Elapsed Time: 0:08:31 ETA:   0:03:26
 75% (3021 of 4009) |###############     | Elapsed Time: 0:09:01 ETA:   0:02:55
 79% (3191 of 4009) |###############     | Elapsed Time: 0:09:31 ETA:   0:02:24
 83% (3361 of 4009) |################    | Elapsed Time: 0:10:01 ETA:   0:01:54
 88% (3531 of 4009) |#################   | Elapsed Time: 0:10:32 ETA:   0:01:24
 92% (3701 of 4009) |##################  | Elapsed Time: 0:11:02 ETA:   0:00:54
 96% (3871 of 4009) |################### | Elapsed Time: 0:11:32 ETA:   0:00:24
100% (4009 of 4009) |####################| Elapsed Time: 0:11:57 Time:  0:11:57
2024-01-27 03:16:27,709 getting explanations took 718.24 seconds.
2024-01-27 03:16:27,710 ===================================================
2024-01-27 03:16:27,710 Falling back to generic labels: label0, label1, ...
2024-01-27 03:16:28,370 converting explanations to tabular took 0.66 seconds.
2024-01-27 03:16:28,370 ===================================================
2024-01-27 03:16:28,373 Wrote baseline used to compute explanations to: /opt/ml/processing/output/explanations_shap/baseline.csv
2024-01-27 03:16:28,666 Wrote 4009 local explanations to: /opt/ml/processing/output/explanations_shap/out.csv
2024-01-27 03:16:28,667 writing local explanations took 0.30 seconds.
2024-01-27 03:16:28,667 ===================================================
/usr/local/lib/python3.9/site-packages/numpy/core/fromnumeric.py:3430: FutureWarning: In a future version, DataFrame.mean(axis=None) will return a scalar mean over the entire DataFrame. To retain the old behavior, use 'frame.mean(axis=0)' or just 'frame.mean()'
  return mean(axis=axis, dtype=dtype, out=out, **kwargs)
2024-01-27 03:16:28,671 aggregating local explanations took 0.00 seconds.
2024-01-27 03:16:28,671 ===================================================
2024-01-27 03:16:28,671 Shap analysis finished.
2024-01-27 03:16:28,671 Calculated global analysis with predictor
2024-01-27 03:16:28,671 Stop using endpoint: sm-clarify-fraud-detect-xgb-model-1706324367-50db
2024-01-27 03:16:28,671 Deleting endpoint configuration with name: sm-clarify-config-1706324367-f930
2024-01-27 03:16:28,882 Deleting endpoint with name: sm-clarify-fraud-detect-xgb-model-1706324367-50db
2024-01-27 03:16:28,973 Model endpoint delivered 5.57605 requests per second and a total of 4011 requests over 719 seconds
2024-01-27 03:16:28,973 Calculated global analysis without predictor
2024-01-27 03:16:40,349 Stop using endpoint: None
2024-01-27 03:17:39,903 ['jupyter', 'nbconvert', '--to', 'html', '--output', '/opt/ml/processing/output/report.html', '/opt/ml/processing/output/report.ipynb', '--template', 'sagemaker-xai']
[NbConvertApp] Converting notebook /opt/ml/processing/output/report.ipynb to html
[NbConvertApp] Writing 460700 bytes to /opt/ml/processing/output/report.html
2024-01-27 03:17:41,063 ['wkhtmltopdf', '-q', '--enable-local-file-access', '/opt/ml/processing/output/report.html', '/opt/ml/processing/output/report.pdf']
2024-01-27 03:17:41,901 Collected analyses: 
{
    "version": "1.0",
    "explanations": {
        "kernel_shap": {
            "label0": {
                "global_shap_values": {
                    "num_vehicles_involved": 0.0005251420166787212,
                    "num_injuries": 0.005174670053097552,
                    "num_witnesses": 0.0018548867228362655,
                    "police_report_available": 0.00024891863559484075,
                    "injury_claim": 0.00026410604914874856,
                    "vehicle_claim": 0.0011094981993782782,
                    "total_claim_amount": 0.0021939764808579067,
                    "incident_month": 0.0006373220028350442,
                    "incident_day": 0.0007678242695401818,
                    "incident_dow": 0.0003802789458608799,
                    "incident_hour": 0.00043901018113492405,
                    "customer_age": 0.00037774901579428914,
                    "months_as_customer": 0.0019821988426191703,
                    "num_claims_past_year": 0.001544412274358187,
                    "num_insurers_past_5_years": 0.0014499348539781178,
                    "policy_deductable": 0.00024411209658750996,
                    "policy_annual_premium": 0.00022855611973701407,
                    "policy_liability": 0.00023337055882671455,
                    "customer_education": 0.00043961153642958425,
                    "auto_year": 0.00030852583114799557,
                    "driver_relationship_other": 0.00025456116805035916,
                    "driver_relationship_child": 0.00024392144710871797,
                    "driver_relationship_spouse": 0.00023127618259727894,
                    "driver_relationship_na": 0.0002252161638325824,
                    "driver_relationship_self": 0.0004970378721023785,
                    "incident_type_collision": 0.0002469685589565143,
                    "incident_type_break-in": 0.0006203415890202677,
                    "incident_type_theft": 0.00021773268433616808,
                    "collision_type_rear": 0.0002517589966514734,
                    "collision_type_side": 0.0002389066373982229,
                    "collision_type_na": 0.00024258704214020067,
                    "collision_type_front": 0.00024007259560348727,
                    "incident_severity_totaled": 0.00023998259726662626,
                    "incident_severity_major": 0.0002302016370513121,
                    "incident_severity_minor": 0.0002467824110957922,
                    "authorities_contacted_fire": 0.00024380323085285822,
                    "authorities_contacted_none": 0.0004777587237169341,
                    "authorities_contacted_police": 0.0002475484132937249,
                    "authorities_contacted_ambulance": 0.00037626700450755807,
                    "policy_state_ca": 0.0002490558175612668,
                    "policy_state_az": 0.00025195189641015974,
                    "policy_state_nv": 0.00021893194971257866,
                    "policy_state_id": 0.00024787078914821197,
                    "policy_state_wa": 0.0005921100393871569,
                    "policy_state_or": 0.0002422650375930483,
                    "customer_gender_other": 0.00023264450770315047,
                    "customer_gender_male": 0.014314261895997393,
                    "customer_gender_female": 0.00024129718211376972
                },
                "expected_value": 0.01326820533722639
            }
        }
    }
}
2024-01-27 03:17:41,902 exit_message: Completed: SageMaker XAI Analyzer ran successfully
----!

SageMaker Clarifyを使用してモデルの説明可能性解析が実行され、解析結果がファイルとして出力されたことが分かります。

Analysis Config
Clarifyの解析構成が表示されており、データセットのタイプ、ヘッダー情報、ラベル、モデルの構成、解析方法(SHAP)などが指定されています。

Processing Job
Clarifyの解析処理ジョブが正常に作成されています。ジョブはSageMakerのプロセッシングジョブとして実行されます。

Data Loading
入力データがCSV形式で読み込まれ、DataFrameとして処理されています。データには48の特徴量が含まれています。

Endpoint Creation
SHAP解析のためにエンドポイントが作成されています。解析を行うためのモデルがデプロイされています。

SHAP Analysis
SHAP解析が開始され、進捗が表示されています。特徴量ごとの寄与度が計算されています。

Endpoint Termination
解析が終了した後、デプロイされたエンドポイントは停止されています。

Output
SHAP解析の結果、グローバルなSHAP値が表示され、各特徴量の重要度が評価されています。customer_gender_maleが最も寄与が高いようです。

Report Generation
最後に、解析の結果をHTMLとPDFのレポートに変換しています。これにより、結果を視覚的に探索できるようになります。

image.png

SageMaker Clarify の説明可能性レポートを PDF 形式で Amazon S3 から SageMaker Studio のローカルディレクトリにダウンロードします。

動作確認済みコード(2024/1/27)

# Copy explainability report and view
!aws s3 cp s3://{write_bucket}/{write_prefix}/clarify-output/explainability/report.pdf ./clarify_explainability_output.pdf

ローカルディレクトリに clarify_explainability_output.pdf という名前のファイルがコピーされます。

【実行後の画面】
image.png

image.png

出力されたレポート
image.png

image.png

image.png

image.png

image.png

image.png

image.png

【実行後の画面】
image.png

image.png

image.png

image.png

image.png

image.png

image.png

image.png

image.png

image.png

image.png

image.png

image.png

image.png

SageMaker Clarify が生成する説明可能性レポートには、個々のサンプルのローカル SHAP 値を含む out.csv というファイルも含まれます。以下のコードブロックをコピーして貼り付け、このファイルを使用して、任意の 1 つの例の説明 (各特微量がモデルの予測に与える影響) を可視化することができます。

動作確認済みコード(2024/1/27)

import matplotlib.pyplot as plt
import matplotlib
%matplotlib inline
local_explanations_out = pd.read_csv(explainability_report_output_uri + "/explanations_shap/out.csv")
feature_names = [str.replace(c, "_label0", "") for c in 
local_explanations_out.columns.to_series()]
local_explanations_out.columns = feature_names

selected_example = 100
print("Example number:", selected_example)

local_explanations_out.iloc[selected_example].plot(
    kind="bar", title="Local explanation for the example number " + str(selected_example), rot=60, figsize=(20, 8)
);

SHAP(SHapley Additive exPlanations)分析から得られたローカルな説明を可視化しています。具体的には、バー・チャートを使用して、選択した例(selected_exampleで指定される)に対する各特徴量の影響度を示しています。選択した例に対する各特徴量の影響度が視覚化され、その例が予測された結果にどのように影響したかがわかります。

行われている手順は次のとおり

  1. matplotlib ライブラリをインポートし、インライン表示を有効にします。
  2. pd.read_csv() を使用して、SHAP分析からのローカルな説明が含まれるCSVファイルを読み込みます。
  3. 特徴量の名前を取得し、ラベルを整形しています。
  4. 選択した例(selected_example)に対するローカルな説明を取得し、それをバー・チャートでプロットしています。

image.png

image.png

【実行後の画面】
image.png

image.png

ステップ 5: リアルタイム推論エンドポイントにモデルをデプロイする

このステップでは、ハイパーパラメータチューニングジョブから得られた最適なモデルをリアルタイム推論エンドポイントにデプロイし、そのエンドポイントを使用して予測を生成します。トレーニング済みのモデルをデプロイする方法としては、SageMaker SDK、AWS SDK - Boto3、SageMaker コンソールなど、複数の方法があります。詳細については、Amazon SageMaker ドキュメントの推論のためのモデルのデプロイを参照してください。この例では、SageMaker SDK を使用してリアルタイムエンドポイントにモデルをデプロイします。

推論(Inference)は、機械学習モデルがトレーニングされた後、新しいデータに対して予測や分類を行うプロセスです。モデルが学習したパターンや関係性を利用して、未知の入力データに対する出力や判断を行います。推論の主な目的は、モデルが実際の業務環境で有用な予測を行うことです。

手順

  1. 新しいデータの入力 モデルに推論を行いたい新しいデータを提供します。このデータはモデルが学習した特徴と同様の形式である必要があります。

  2. モデルの実行 学習済みモデルに対して、新しいデータを入力して実行(推論)します。モデルはこのデータに基づいて予測や分類を行います。

  3. 出力の取得 モデルの推論が完了すると、その結果として予測されたクラス、値、または確率などの出力が得られます。

推論は、機械学習モデルが実際の問題に対して有用であるかどうかを評価する上で非常に重要です。トレーニングフェーズではモデルがデータから学習することが重視されますが、推論フェーズではその学習結果を元に新しいデータに対して適切な予測をすることが求められます。

image.png

image.png

動作確認済みコード(2024/1/27)

best_train_job_name = tuner.best_training_job()

model_path = estimator_output_uri + '/' + best_train_job_name + '/output/model.tar.gz'
training_image = retrieve(framework="xgboost", region=region, version="1.3-1")
create_model_config = {"model_data":model_path,
                       "role":sagemaker_role,
                       "image_uri":training_image,
                       "name":endpoint_name_prefix,
                       "predictor_cls":sagemaker.predictor.Predictor
                       }
# Create a SageMaker model
model = sagemaker.model.Model(**create_model_config)

# Deploy the best model and get access to a SageMaker Predictor
predictor = model.deploy(initial_instance_count=predictor_instance_count, 
                         instance_type=predictor_instance_type,
                         serializer=CSVSerializer(),
                         deserializer=CSVDeserializer())
print(f"\nModel deployed at endpoint : {model.endpoint_name}")

ハイパーパラメータチューニングジョブから最適なモデルを取得し、そのモデルをSageMakerエンドポイントにデプロイしています。実行すると、ハイパーパラメータチューニングで見つけた最適なモデルがエンドポイントにデプロイされ、新しいデータに対して推論を行うためのエンドポイントが利用可能になります。

best_training_job() メソッドを使用して、ハイパーパラメータチューニングジョブから最適なトレーニングジョブの名前を取得します。

モデルの保存場所 (model_path) を構築します。これはトレーニングジョブの出力ディレクトリ内の model.tar.gz です。

SageMaker Estimatorを使用して、トレーニングしたモデルをSageMakerモデルとして構築します。Predictor クラスを使用することで、推論のための SageMaker ピクチャライザ(predictor)を作成します。

deploy() メソッドを使用して、モデルをエンドポイントにデプロイします。ここで、推論に使用されるインスタンスの数とタイプを指定できます。

エンドポイントの名前が表示されます。

【実行後の画面】
image.png

image.png

image.png

【実行後の画面】
image.png

image.png

image.png

image.png

image.png

image.png

モデルがエンドポイントにデプロイされたので、REST API を直接呼び出す (このチュートリアルでは説明しません)、AWS SDK を使用する、SageMaker Studio のグラフィカルインターフェイスを使用する、または SageMaker Python SDK を使用することで呼び出すことができます。このチュートリアルでは、デプロイステップで利用可能になる SageMaker Predictor を使用して、1 つまたは複数のテストサンプルでリアルタイムのモデル予測を取得します。以下のコードブロックをコピーして貼り付け、エンドポイントを呼び出し、テストデータのサンプルを 1 つ送信します。

image.png

動作確認済みコード(2024/1/27)

# Sample test data
test_df = pd.read_csv(test_data_uri)
payload = test_df.drop(["fraud"], axis=1).iloc[0].to_list()
print(f"Model predicted score : {float(predictor.predict(payload)[0][0]):.3f}, True label : {test_df['fraud'].iloc[0]}")

モデルがトレーニングされた特徴量を使用して新しいデータを評価する手順を示しています。

  1. テストデータ(test_df)を読み込みます。
  2. テストデータからラベル("fraud")を削除し、モデルの入力として使用する特徴量のペイロードを取得します。
  3. predictor.predict() メソッドを使用して、モデルにペイロードを送信し、予測されたスコアを取得します。
  4. 予測されたスコアとテストデータの真のラベルを表示します。

このコードを実行すると、モデルが新しいデータに対してどのように予測するかを確認できます。

【実行後の画面】
image.png

ステップ 6: リソースをクリーンアップする

image.png

# Delete model
try:
 sess.delete_model(xgb_model_name)
except:
 pass
sess.delete_model(model.name)

# Delete inference endpoint config
sess.delete_endpoint_config(endpoint_config_name=predictor._get_endpoint_config_name())

# Delete inference endpoint
sess.delete_endpoint(endpoint_name=model.endpoint_name)

SageMaker でデプロイされたモデルと関連するエンドポイントを削除するためのものです。

  1. 最初に、xgb_model_name という名前のモデルを sess インスタンスを使用して削除します。delete_model メソッドが呼び出される前に、try-except ブロックがあります。これは、xgb_model_name が存在しない場合に例外が発生しないようにするためのものです。

  2. 次に、model.name という名前のモデルを sess インスタンスを使用して削除します。このモデルは先にデプロイされています。

  3. モデルを削除した後、そのモデルに関連するエンドポイント構成(endpoint_config_name)を delete_endpoint_config メソッドを使用して削除します。

  4. 最後に、model.endpoint_name という名前のエンドポイントを delete_endpoint メソッドを使用して削除します。これにより、エンドポイントが停止され、使用されなくなります。

これらのステップを通じて、SageMaker でデプロイされたモデルをクリーンアップし、関連するリソースを削除しています。

image.png

image.png

Sagemaker のkernelを停止させ、ユーザプロファイル、ドメイン、s3リソースを削除する必要があります。

2
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?