SHAP値でのエラー
解決したいこと
ランダムフォレストを用いてSHAP値を求めたが、予想される結果がでなかった
考えられる問題について教えて頂きたいです。
発生している問題・エラー
本来であればCODの赤い点がマイナスに集中してほしいが、実際は逆になってしまう。なぜだろうか。
ソースコード
下記のコードでは、クラスターA~Cの3つデータセットが用意されていて、クラスターごとにランダムフォレストを用いたSHAP値を求めている。今回、クラスターCで問題が出てしまった。
import os
# Ensure SHAP uses CPU only
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["USE_CUDA"] = "0"
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import GridSearchCV
import shap
import matplotlib.pyplot as plt
from matplotlib import font_manager
from sklearn.metrics import r2_score
#── 日本語フォント設定 ─────────────────────────────────
#jfnt = font_manager.FontProperties(fname=r"C:\Windows\Fonts\meiryo.ttc")
import matplotlib as mpl
mpl.rcParams['font.family'] = 'Meiryo'
mpl.rcParams['font.size'] = 12
mpl.rcParams['axes.unicode_minus'] = False
^
#── ファイルパス ────────────────────────────────────────
lake_path = r"C:\usr\yoshikawa\DO\DO_統合_環境白書追加.xlsx"
cluster_path = r"C:\usr\yoshikawa\DO\DO_クラスタ結果_回復値付き.xlsx"
output_dir = r"C:\usr\yoshikawa\DO\rf_tuned_shap4"
os.makedirs(output_dir, exist_ok=True)
#── データ読み込み・年度列追加 ──────────────────────────────────
usecols = [
"採水月日",
"Wind Speed (m/s)",
"Precipitation (mm)",
" Temperature (°C)",
"COD (mg/l)",
"phytoplankton(0.5)(μ/l)",
"N(mg/l)",
"PO4-P (mg/l)",
"DO (mg/l)",
"Nutrient Salts",
"SSI(J/m²)",
"NH4-N (mg/l)",
"NO2-N (mg/l)",
"TOC (mg/l)",
"NO3-N (mg/l)",
"N(mg/l)"
]
df_lake = pd.read_excel(lake_path,
sheet_name="Sheet1",
usecols=usecols,
engine="openpyxl")
df_lake["採水月日"] = pd.to_datetime(df_lake["採水月日"])
df_lake["年度"] = np.where(
df_lake["採水月日"].dt.month >= 4,
df_lake["採水月日"].dt.year,
df_lake["採水月日"].dt.year - 1
)
#── クラスタ結果読み込み&マージ ────────────────────────
df_cl = pd.read_excel(cluster_path,
usecols=["年度", "Cluster"],
engine="openpyxl",
sheet_name="Sheet3", )
df = pd.merge(df_lake, df_cl, on="年度", how="inner").reset_index(drop=True)
#── 特徴量 & 目的変数設定 ────────────────────────────────
features = [
# "Wind Speed (m/s)",
# "Precipitation (mm)",
# " Temperature (°C)",
# "COD (mg/l)",
# "phytoplankton(0.5)(μ/l)",
# "N(mg/l)",
# "PO4-P (mg/l)",
# "Nutrient Salts",
"SSI(J/m²)",
"NH4-N (mg/l)",
"COD (mg/l)"
]
X = df[features].values
y = df["DO (mg/l)"].values
# 2002年4月以降にデータを絞る
df_lake = df_lake[df_lake["採水月日"] >= pd.Timestamp(2002, 4, 1)].copy()
#── データ標準化 ───────────────────────────────────
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
#── ハイパーパラメータ探索設定 ─────────────────────────
param_grid = {
'n_estimators': [100, 200],
'max_depth': [None, 10, 20],
'min_samples_split': [2, 5],
'min_samples_leaf': [1, 2],
'max_features': ['sqrt', 'log2']
}
#── クラスタ毎にチューニング+SHAP解析+プロット出力 ─────────────────
for cl in sorted(df["Cluster"].unique()):
mask = (df["Cluster"] == cl)
Xc, yc = X_scaled[mask], y[mask]
# 1) 最適RFモデル探索
rf = RandomForestRegressor(random_state=0)
grid = GridSearchCV(rf,
param_grid,
cv=3,
scoring='neg_mean_squared_error',
n_jobs=-1)
grid.fit(Xc, yc)
best_rf = grid.best_estimator_
# 2) 決定係数 (R²) の計算
y_pred = best_rf.predict(Xc)
r2 = r2_score(yc, y_pred)
print(f"[Cluster {cl}] R² = {r2:.3f}")
# 2) SHAP値算出 (符号付き)
explainer = shap.TreeExplainer(best_rf,
feature_perturbation="tree_path_dependent")
shap_vals = explainer.shap_values(Xc) # shape = (n_samples, n_features)
# A) Summary Plot (ドットプロット)
plt.figure(figsize=(8, 6))
shap.summary_plot(
shap_vals,
Xc,
feature_names=features,
plot_type="dot",
show=False,
color_bar_label="Feature value"
)
plt.title(f"Cluster {cl} SHAP Summary", fontsize=16)
plt.tight_layout()
plt.savefig(os.path.join(output_dir, f"cluster_{cl}_shap_summary.png"), dpi=150)
plt.close()
# B) 平均 |SHAP| の棒グラフ
shap_abs_mean = np.abs(shap_vals).mean(axis=0)
fig1, ax1 = plt.subplots(figsize=(8, 5))
y_pos = np.arange(len(features))
ax1.barh(y_pos, shap_abs_mean, color='orchid')
ax1.set_yticks(y_pos)
ax1.set_yticklabels(features, fontsize=12)
ax1.invert_yaxis()
ax1.set_xlim(0, shap_abs_mean.max() * 1.1)
ax1.set_xlabel("平均 |SHAP値|", fontsize=12)
ax1.set_title(f"Cluster {cl} Mean |SHAP|", fontsize=14)
ax1.grid(axis="x", linestyle="--", alpha=0.3)
for i, v in enumerate(shap_abs_mean):
ax1.text(v + shap_abs_mean.max()*0.01, i, f"{v:.3f}",
fontsize=10, va="center")
fig1.tight_layout()
ax1.set_xlim(0, 2)
fig1.savefig(os.path.join(output_dir, f"cluster_{cl}_shap_bar_abs.png"), dpi=150)
plt.close(fig1)
# ベストパラメータ保存
with open(os.path.join(output_dir, f"cluster_{cl}_best_params.txt"), 'w') as f:
f.write(str(grid.best_params_))
print(f"[Cluster {cl}] SHAP plots saved; best_params = {grid.best_params_}")
#── 全データで同様の解析 ───────────────────────────────────
# 1) 最適モデル探索
rf_all = RandomForestRegressor(random_state=0)
grid_all = GridSearchCV(rf_all, param_grid, cv=3,
scoring='neg_mean_squared_error', n_jobs=-1)
grid_all.fit(X_scaled, y)
best_rf_all = grid_all.best_estimator_
# 2) 全データ R²
y_all_pred = best_rf_all.predict(X_scaled)
r2_all = r2_score(y, y_all_pred)
print(f"[All Data] R² = {r2_all:.3f}")
with open(os.path.join(output_dir, "all_r2.txt"), 'w', encoding='utf-8') as f:
f.write(f"R² = {r2_all:.3f}\n")
# 3) SHAP 値計算
explainer_all = shap.TreeExplainer(best_rf_all, feature_perturbation="tree_path_dependent")
shap_vals_all = explainer_all.shap_values(X_scaled)
# A) Summary Plot
plt.figure(figsize=(8, 6))
shap.summary_plot(
shap_vals_all, X_scaled,
feature_names=features,
plot_type="dot",
show=False,
color_bar_label="Feature value"
)
plt.title("All Data SHAP サマリ", fontsize=16)
plt.tight_layout()
plt.savefig(os.path.join(output_dir, "all_shap_summary.png"), dpi=150)
plt.close()
# B) 平均 |SHAP| 棒グラフ
shap_abs_mean_all = np.abs(shap_vals_all).mean(axis=0)
fig3, ax3 = plt.subplots(figsize=(8, 5))
y_pos = np.arange(len(features))
ax3.barh(y_pos, shap_abs_mean_all, color='teal')
ax3.set_yticks(y_pos)
ax3.set_yticklabels(features)
ax3.invert_yaxis()
ax3.set_xlim(0, shap_abs_mean_all.max() * 1.1)
ax3.set_xlabel("平均 |SHAP値|")
ax3.set_title("All Data 平均 |SHAP|", fontsize=14)
ax3.grid(axis="x", linestyle="--", alpha=0.3)
for i, v in enumerate(shap_abs_mean_all):
ax3.text(v + shap_abs_mean_all.max()*0.01, i, f"{v:.3f}", va="center")
fig3.tight_layout()
ax3.set_xlim(0, 2)
fig3.savefig(os.path.join(output_dir, "all_shap_bar_abs.png"), dpi=150)
plt.close(fig3)
# ベストパラメータ(全データ)保存
with open(os.path.join(output_dir, "all_best_params.txt"), 'w', encoding='utf-8') as f:
f.write(str(grid_all.best_params_))
print("[All Data] SHAP プロット保存完了; best_params =", grid_all.best_params_)
自分で試したこと
説明変数を何パターンか取捨選択して結果を出力した。
何か気が付きましたら教えていただけますと幸いです。
0 likes