背景
こちらの記事にて、SageMaker ClarifyでPDPレポートを作成してみた時に、他にSHAP値も出力出来るようでしたので、今回の記事では、SHAP値レポートの作成を試してみました。
環境
sagemaker:2.219.0
試した事(概要)
Nishikaのコンペサイトのデータを使って、SageMakerの組み込みアルゴリズムのXGBoostで予測を行い、その予測に対してSageMaker ClarifyでSHAP値レポートを作成してみました。
試した事(詳細)
1. 準備
データの取得、前処理、およびXGBoostモデルの学習は、こちらの記事のものを流用しました。
具体的には、記事の2.2.までは同じになります。
2. 実装
2.1. Clarifyを行うためのモデルを作成
学習を行ってS3に作成されたmodel.tar.gzファイルを使って、Modelクラスをインスタンス化します。
training_job_name = "test-2024714111453"
xgboost_model = sagemaker.model.Model(image_uri=xgboost_container,
model_data="{a}/{b}/output/model.tar.gz".format(a=s3_train_output_path,
b=training_job_name),
role=ROLE,
predictor_cls=sagemaker.predictor.RealTimePredictor)
インスタンス化したModelクラスのオブジェクトを使って、Clarify用のモデルを作成します。
model_name_for_clarify_SHAP = "test-xgboost-model-for-clarify-SHAP"
container_for_clarify_SHAP = xgboost_model.prepare_container_def(instance_type="ml.m5.large")
SAGEMAKER_SESSION.create_model(name=model_name_for_clarify_SHAP,
role=ROLE,
container_defs=container_for_clarify_SHAP)
これでClarify用のモデルが作成されました。
マネジメントコンソールのSageMakerの画面でもモデルを確認出来ます。
2.2. Clarifyでレポートを作成するための準備
S3に保存したsimple_valid_for_clarify.csvを使って、作成したXGBoostモデルの予測に対するレポートをClarifyで作成するための準備を行います。
まずは、simple_valid_for_clarify.csvのS3パスを準備します。
s3_clarify_SHAP_report_output_path = "s3://" + os.path.join(S3_BUCKET, S3_PREFIX) + "/clarify_SHAP_report"
VALID_FOR_CLARIFY_DATA_PATH = "simple_valid_for_clarify.csv"
s3_valid_for_clarify_data_file_path = "s3://" + os.path.join(S3_BUCKET, S3_PREFIX, VALID_FOR_CLARIFY_DATA_PATH)
ClarifyでPDP&SHAP値レポートを作成してみます。
SHAP値のみでなく、PDPと一緒のレポートになるみたいです。
PDP&SHAP値レポートを作成するためには、5つの設定が必要になります。
1つ目はDataConfigになります。
引数のheadersには、CSVファイル化の際に削除したカラム名を設定する形になります。PDP&SHAP値レポートで文字化けしないように、英語で設定します。
引数のlabelには、目的変数のカラム名を設定します。
clarify_SHAP_report_data_config = sagemaker.clarify.DataConfig(s3_data_input_path=s3_valid_for_clarify_data_file_path,
s3_output_path=s3_clarify_SHAP_report_output_path,
label="predicted_price",
headers=["predicted_price", "prefecture_code", "distance_from_station", "area", "built_area", "built_volumn"],
dataset_type="text/csv")
2つ目はModelConfigになります。
引数のmodel_nameには、先程作成したClarify用のモデルの名前を設定します。
(今回の場合は、model_name_for_clarify_SHAP = "test-xgboost-model-for-clarify-SHAP"の名前になります。)
clarify_SHAP_report_model_config = sagemaker.clarify.ModelConfig(model_name=model_name_for_clarify_SHAP,
instance_count=1,
instance_type="ml.m5.large",
accept_type="text/csv",
content_type="text/csv")
3つ目はModelPredictedLabelConfigになります。
設定する引数は、回帰タスク/二値分類タスク/多値分類タスク毎に異なるみたいでして、今回の回帰タスクでは特に引数は設定しない形のようです。
>Regression task: The model returns the score, e.g. 1.2. We don’t need to specify anything.
clarify_SHAP_report_predict_config = sagemaker.clarify.ModelPredictedLabelConfig()
4つ目はPDPConfigになります。
引数のfeaturesには、説明変数のカラム名を設定します。
(DataConfigの引数headersは全カラム名、DataConfigの引数labelは目的変数のカラム名、PDPConfigの引数featuresには説明変数のカラム名、を設定するイメージになります。)
clarify_SHAP_report_pdp_config = sagemaker.clarify.PDPConfig(features=["prefecture_code", "distance_from_station", "area", "built_area", "built_volumn"],
grid_resolution=15)
5つ目はSHAPConfigになります。
# S3に保存されているsimple_valid_for_clarify.csvをNotebookインスタンスのローカルにダウンロード
boto3.Session().resource("s3").Bucket(S3_BUCKET).Object(os.path.join(S3_PREFIX, VALID_FOR_CLARIFY_DATA_PATH)).download_file("./data/Nishika_ApartmentPrice/simple_valid_for_clarify.csv")
# ダウンロードしたcsvファイルをDataFrame化(カラム名は削除されているので、引数headerはNone)
s3_valid_for_clarify_data_df = pd.read_csv(filepath_or_buffer="./data/Nishika_ApartmentPrice/simple_valid_for_clarify.csv",
header=None)
# 説明変数のカラムの値をリストで取得(説明変数は2列目以降)
s3_valid_for_clarify_data_value_list = s3_valid_for_clarify_data_df.iloc[:, 1:].values.tolist()
clarify_SHAP_report_config = sagemaker.clarify.SHAPConfig(baseline=s3_valid_for_clarify_data_value_list,
num_samples=s3_valid_for_clarify_data_df.shape[0],
agg_method="mean_abs")
2.3. Clarifyでレポートを作成
まずは、Processorを作成します。
Clarifyでレポートを作成する時は、このProcessorを使う形になります。
clarify_processor = sagemaker.clarify.SageMakerClarifyProcessor(role=ROLE,
instance_count=1,
instance_type="ml.m5.large",
max_runtime_in_seconds=3600,
sagemaker_session=SAGEMAKER_SESSION)
作成したProcessorのrun_explainabilityメソッドを使って、ClarifyのPDP&SHAP値レポートを作成します。
引数には、先程準備したDataConfig、ModelConfig、ModelPredictedLabelConfig、PDPConfig、SHAPConfigの5つを設定します。
clarify_processor.run_explainability(data_config=clarify_SHAP_report_data_config,
model_config=clarify_SHAP_report_model_config,
explainability_config=[clarify_SHAP_report_pdp_config, clarify_SHAP_report_config],
model_scores=clarify_SHAP_report_predict_config,
wait=True,
logs=True)
INFO:sagemaker:Creating processing-job with name Clarify-Explainability-2024-07-25-02-55-38-816
......................WARNING:root:logging.conf not found when configuring logging, using default logging configuration.
INFO:sagemaker-clarify-processing:Starting SageMaker Clarify Processing job
INFO:analyzer.data_loading.data_loader_util:Analysis config path: /opt/ml/processing/input/config/analysis_config.json
INFO:analyzer.data_loading.data_loader_util:Analysis result path: /opt/ml/processing/output
[NbConvertApp] Converting notebook /opt/ml/processing/output/report.ipynb to html
[NbConvertApp] Writing 846474 bytes to /opt/ml/processing/output/report.html
INFO:analyzer.utils.util:['wkhtmltopdf', '-q', '--enable-local-file-access', '/opt/ml/processing/output/report.html', '/opt/ml/processing/output/report.pdf']
INFO:analyzer.utils.system_util:exit_message: Completed: SageMaker XAI Analyzer ran successfully
INFO:py4j.clientserver:Closing down clientserver connection
---!
DataConfigの引数s3_output_pathで設定したS3フォルダ内に、複数のファイルが作成されました。
2.4. レポートを確認
作成されたファイルのreport.pdfをローカルにダウンロードして開いてみます。
説明変数毎の貢献度や分布を確認する事が出来ました。
今回は以上になります。
まとめ
SageMaker ClarifyのレポートにSHAP値も追加する事が出来ました。SHAP値は「それぞれのデータ毎の予測結果に対して、予測結果の平均とのズレは、どの説明変数がどれ程寄与しているかを定量的に示したもの」というレベルの理解に留まっていますため、レポートをキチンと読み解けていませんが、予測モデルの解釈性の観点からも、SHAP値をキチンと理解しようと思いました。
参考