0
0

【AWS SageMaker Studio】SageMaker ClarifyでPDP&SHAP値レポートの作成を試してみた

Posted at

背景

こちらの記事にて、SageMaker ClarifyでPDPレポートを作成してみた時に、他にSHAP値も出力出来るようでしたので、今回の記事では、SHAP値レポートの作成を試してみました。

環境

sagemaker:2.219.0

試した事(概要)

Nishikaのコンペサイトのデータを使って、SageMakerの組み込みアルゴリズムのXGBoostで予測を行い、その予測に対してSageMaker ClarifyでSHAP値レポートを作成してみました。

試した事(詳細)

1. 準備

データの取得、前処理、およびXGBoostモデルの学習は、こちらの記事のものを流用しました。

具体的には、記事の2.2.までは同じになります。

2. 実装

2.1. Clarifyを行うためのモデルを作成

学習を行ってS3に作成されたmodel.tar.gzファイルを使って、Modelクラスをインスタンス化します。

hogehoge.ipynb
training_job_name = "test-2024714111453"
xgboost_model = sagemaker.model.Model(image_uri=xgboost_container,
                                      model_data="{a}/{b}/output/model.tar.gz".format(a=s3_train_output_path,
                                                                                      b=training_job_name),
                                      role=ROLE,
                                      predictor_cls=sagemaker.predictor.RealTimePredictor)

インスタンス化したModelクラスのオブジェクトを使って、Clarify用のモデルを作成します。

hogehoge.ipynb
model_name_for_clarify_SHAP = "test-xgboost-model-for-clarify-SHAP"
container_for_clarify_SHAP = xgboost_model.prepare_container_def(instance_type="ml.m5.large")
SAGEMAKER_SESSION.create_model(name=model_name_for_clarify_SHAP,
                               role=ROLE,
                               container_defs=container_for_clarify_SHAP)

これでClarify用のモデルが作成されました。
マネジメントコンソールのSageMakerの画面でもモデルを確認出来ます。

スクショ1.PNG

2.2. Clarifyでレポートを作成するための準備

S3に保存したsimple_valid_for_clarify.csvを使って、作成したXGBoostモデルの予測に対するレポートをClarifyで作成するための準備を行います。
まずは、simple_valid_for_clarify.csvのS3パスを準備します。

hogehoge.ipynb
s3_clarify_SHAP_report_output_path = "s3://" + os.path.join(S3_BUCKET, S3_PREFIX) + "/clarify_SHAP_report"
VALID_FOR_CLARIFY_DATA_PATH = "simple_valid_for_clarify.csv"
s3_valid_for_clarify_data_file_path = "s3://" + os.path.join(S3_BUCKET, S3_PREFIX, VALID_FOR_CLARIFY_DATA_PATH)

ClarifyでPDP&SHAP値レポートを作成してみます。
SHAP値のみでなく、PDPと一緒のレポートになるみたいです。
PDP&SHAP値レポートを作成するためには、5つの設定が必要になります。
1つ目はDataConfigになります。
引数のheadersには、CSVファイル化の際に削除したカラム名を設定する形になります。PDP&SHAP値レポートで文字化けしないように、英語で設定します。
引数のlabelには、目的変数のカラム名を設定します。

hogehoge.ipynb
clarify_SHAP_report_data_config = sagemaker.clarify.DataConfig(s3_data_input_path=s3_valid_for_clarify_data_file_path,
                                                               s3_output_path=s3_clarify_SHAP_report_output_path,
                                                               label="predicted_price",
                                                               headers=["predicted_price", "prefecture_code", "distance_from_station", "area", "built_area", "built_volumn"],
                                                               dataset_type="text/csv")

2つ目はModelConfigになります。
引数のmodel_nameには、先程作成したClarify用のモデルの名前を設定します。
(今回の場合は、model_name_for_clarify_SHAP = "test-xgboost-model-for-clarify-SHAP"の名前になります。)

hogehoge.ipynb
clarify_SHAP_report_model_config = sagemaker.clarify.ModelConfig(model_name=model_name_for_clarify_SHAP,
                                                                 instance_count=1,
                                                                 instance_type="ml.m5.large",
                                                                 accept_type="text/csv",
                                                                 content_type="text/csv")

3つ目はModelPredictedLabelConfigになります。
設定する引数は、回帰タスク/二値分類タスク/多値分類タスク毎に異なるみたいでして、今回の回帰タスクでは特に引数は設定しない形のようです。

>Regression task: The model returns the score, e.g. 1.2. We don’t need to specify anything.

hogehoge.ipynb
clarify_SHAP_report_predict_config = sagemaker.clarify.ModelPredictedLabelConfig()

4つ目はPDPConfigになります。
引数のfeaturesには、説明変数のカラム名を設定します。
(DataConfigの引数headersは全カラム名、DataConfigの引数labelは目的変数のカラム名、PDPConfigの引数featuresには説明変数のカラム名、を設定するイメージになります。)

hogehoge.ipynb
clarify_SHAP_report_pdp_config = sagemaker.clarify.PDPConfig(features=["prefecture_code", "distance_from_station", "area", "built_area", "built_volumn"],
                                                             grid_resolution=15)

5つ目はSHAPConfigになります。

hogehoge.ipynb
# S3に保存されているsimple_valid_for_clarify.csvをNotebookインスタンスのローカルにダウンロード
boto3.Session().resource("s3").Bucket(S3_BUCKET).Object(os.path.join(S3_PREFIX, VALID_FOR_CLARIFY_DATA_PATH)).download_file("./data/Nishika_ApartmentPrice/simple_valid_for_clarify.csv")
# ダウンロードしたcsvファイルをDataFrame化(カラム名は削除されているので、引数headerはNone)
s3_valid_for_clarify_data_df = pd.read_csv(filepath_or_buffer="./data/Nishika_ApartmentPrice/simple_valid_for_clarify.csv",
                                           header=None)
# 説明変数のカラムの値をリストで取得(説明変数は2列目以降)
s3_valid_for_clarify_data_value_list = s3_valid_for_clarify_data_df.iloc[:, 1:].values.tolist()
hogehoge.ipynb
clarify_SHAP_report_config = sagemaker.clarify.SHAPConfig(baseline=s3_valid_for_clarify_data_value_list,
                                                          num_samples=s3_valid_for_clarify_data_df.shape[0],
                                                          agg_method="mean_abs")
2.3. Clarifyでレポートを作成

まずは、Processorを作成します。
Clarifyでレポートを作成する時は、このProcessorを使う形になります。

hogehoge.ipynb
clarify_processor = sagemaker.clarify.SageMakerClarifyProcessor(role=ROLE,
                                                                instance_count=1,
                                                                instance_type="ml.m5.large",
                                                                max_runtime_in_seconds=3600,
                                                                sagemaker_session=SAGEMAKER_SESSION)

作成したProcessorのrun_explainabilityメソッドを使って、ClarifyのPDP&SHAP値レポートを作成します。
引数には、先程準備したDataConfig、ModelConfig、ModelPredictedLabelConfig、PDPConfig、SHAPConfigの5つを設定します。

hogehoge.ipynb
clarify_processor.run_explainability(data_config=clarify_SHAP_report_data_config,
                                     model_config=clarify_SHAP_report_model_config,
                                     explainability_config=[clarify_SHAP_report_pdp_config, clarify_SHAP_report_config],
                                     model_scores=clarify_SHAP_report_predict_config,
                                     wait=True,
                                     logs=True)
INFO:sagemaker:Creating processing-job with name Clarify-Explainability-2024-07-25-02-55-38-816
......................WARNING:root:logging.conf not found when configuring logging, using default logging configuration.
INFO:sagemaker-clarify-processing:Starting SageMaker Clarify Processing job
INFO:analyzer.data_loading.data_loader_util:Analysis config path: /opt/ml/processing/input/config/analysis_config.json
INFO:analyzer.data_loading.data_loader_util:Analysis result path: /opt/ml/processing/output

[NbConvertApp] Converting notebook /opt/ml/processing/output/report.ipynb to html
[NbConvertApp] Writing 846474 bytes to /opt/ml/processing/output/report.html
INFO:analyzer.utils.util:['wkhtmltopdf', '-q', '--enable-local-file-access', '/opt/ml/processing/output/report.html', '/opt/ml/processing/output/report.pdf']
INFO:analyzer.utils.system_util:exit_message: Completed: SageMaker XAI Analyzer ran successfully
INFO:py4j.clientserver:Closing down clientserver connection
---!

DataConfigの引数s3_output_pathで設定したS3フォルダ内に、複数のファイルが作成されました。

スクショ2.PNG

スクショ3a.PNG

2.4. レポートを確認

作成されたファイルのreport.pdfをローカルにダウンロードして開いてみます。

スクショ4.PNG

スクショ5.PNG

説明変数毎の貢献度や分布を確認する事が出来ました。


今回は以上になります。

まとめ

SageMaker ClarifyのレポートにSHAP値も追加する事が出来ました。SHAP値は「それぞれのデータ毎の予測結果に対して、予測結果の平均とのズレは、どの説明変数がどれ程寄与しているかを定量的に示したもの」というレベルの理解に留まっていますため、レポートをキチンと読み解けていませんが、予測モデルの解釈性の観点からも、SHAP値をキチンと理解しようと思いました。

参考

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0