LoginSignup
191
164

More than 1 year has passed since last update.

SHAPで因果関係を説明できる?

Last updated at Posted at 2023-02-25

はじめに

予測モデル(機械学習モデル)を解釈するのに有用なSHAPを用いて因果関係を説明することができるか、についてPythonによるシミュレーションを交えてまとめました。内容に誤り等ございましたら、ご指摘いただけますと幸いです。

結論

基本的に、SHAPで因果関係は説明できません。これは、SHAPが予測モデルの因果ではなく相関を明らかにするものであるからです。

そこで今回は、予測モデルをSHAPで解釈する上でありがちなミスリーディングや、それに関連する因果効果を推定するためのアプローチについて記載しています。

そもそもSHAPとは

SHAPとはSHapley Additive exPlanationsの略で、協力ゲーム理論のShapley Valueを機械学習に応用した手法です。「その予測モデルがなぜ、その予測値を算出しているか」を解釈するためのツールとしてオープンソースのライブラリが開発されており、Python等で簡単に実行することができます。

本記事では、協力ゲーム理論やShapley Value、SHAPのアルゴリズムに関する説明はしません。それらについて知りたい方は、下記の文献などをご参照ください。

協力ゲーム理論(Shapley Value)に関する参考文献

SHAPに関する参考文献

Pythonによるシミュレーション

それでは、予測モデル(機械学習モデル)にSHAPを用いた結果が因果関係とはならないようなケースについて、Pythonを用いてシミュレーションしていきます。

データ生成

SHAPの公式ドキュメントに従って、データを生成します。

# 必要なライブラリをインポート
import numpy as np
import pandas as pd
import scipy.stats
import sklearn
import xgboost
import shap

# データを生成するためのクラスや関数を作成
class FixableDataFrame(pd.DataFrame):
    """ Helper class for manipulating generative models.
    """
    def __init__(self, *args, fixed={}, **kwargs):
        self.__dict__["__fixed_var_dictionary"] = fixed
        super(FixableDataFrame, self).__init__(*args, **kwargs)
    def __setitem__(self, key, value):
        out = super(FixableDataFrame, self).__setitem__(key, value)
        if isinstance(key, str) and key in self.__dict__["__fixed_var_dictionary"]:
            out = super(FixableDataFrame, self).__setitem__(key, self.__dict__["__fixed_var_dictionary"][key])
        return out

def generator(n, fixed={}, seed=0):
    """ The generative model for our subscriber retention example.
    """
    if seed is not None:
        np.random.seed(seed)
    X = FixableDataFrame(fixed=fixed)

    # 訪問販売の回数
    X["Sales calls"] = np.random.uniform(0, 4, size=(n,)).round()

    # 交流回数
    X["Interactions"] = X["Sales calls"] + np.random.poisson(0.2, size=(n,))

    # 消費者が住む地域の経済状況
    X["Economy"] = np.random.uniform(0, 1, size=(n,))

    # 最終更新時期
    X["Last upgrade"] = np.random.uniform(0, 20, size=(n,))

    # プロダクトのニーズ(未観測)
    X["Product need"] = (X["Sales calls"] * 0.1 + np.random.normal(0, 1, size=(n,)))

    # 割引額
    X["Discount"] = ((1-scipy.special.expit(X["Product need"])) * 0.5 + 0.5 * np.random.uniform(0, 1, size=(n,))) / 2

    # ユーザーの利用頻度
    X["Monthly usage"] = scipy.special.expit(X["Product need"] * 0.3 + np.random.normal(0, 1, size=(n,)))

    # 広告費用
    X["Ad spend"] = X["Monthly usage"] * np.random.uniform(0.99, 0.9, size=(n,)) + (X["Last upgrade"] < 1) + (X["Last upgrade"] < 2)

    # ユーザーのバグ遭遇回数(未観測)
    X["Bugs faced"] = np.array([np.random.poisson(v*2) for v in X["Monthly usage"]])

    # ユーザーのバグ報告回数
    X["Bugs reported"] = (X["Bugs faced"] * scipy.special.expit(X["Product need"])).round()

    # ユーザーがサブスクを更新する確率
    X["Did renew"] = scipy.special.expit(7 * (
          0.18 * X["Product need"] \
        + 0.08 * X["Monthly usage"] \
        + 0.1 * X["Economy"] \
        + 0.05 * X["Discount"] \
        + 0.05 * np.random.normal(0, 1, size=(n,)) \
        + 0.05 * (1 - X['Bugs faced'] / 20) \
        + 0.005 * X["Sales calls"] \
        + 0.015 * X["Interactions"] \
        + 0.1 / (X["Last upgrade"]/4 + 0.25)
        + X["Ad spend"] * 0.0 - 0.45
    ))
    # ユーザーがサブスクを更新したかどうか(更新していたら1)
    X["Did renew"] = scipy.stats.bernoulli.rvs(X["Did renew"])

    return X

def user_retention_dataset():
    """ The observed data for model training.
    """
    n = 10000
    X_full = generator(n)
    y = X_full["Did renew"]
    X = X_full.drop(["Did renew", "Product need", "Bugs faced"], axis=1)
    return X, y

def fit_xgboost(X, y):
    """ Train an XGBoost model with early stopping.
    """
    X_train,X_test,y_train,y_test = sklearn.model_selection.train_test_split(X, y)
    dtrain = xgboost.DMatrix(X_train, label=y_train)
    dtest = xgboost.DMatrix(X_test, label=y_test)
    model = xgboost.train(
        { "eta": 0.001, "subsample": 0.5, "max_depth": 2, "objective": "reg:logistic"}, dtrain, num_boost_round=200000,
        evals=((dtest, "test"),), early_stopping_rounds=20, verbose_eval=False
    )
    return model

# データを格納
X, y = user_retention_dataset()

# 値の確認
display(pd.DataFrame(X).head())
display(pd.DataFrame(y).head())

(出力結果)
スクリーンショット 2023-02-24 23.55.51.png

設定

サブスク購入者が商品のサブスク購入を更新するかどうか(Did renew = 1 or 0)を予測するモデルを考えます。

先ほどPythonで生成したデータのうち、下記のデータが観測でき、モデルの特徴量$\ X \ $として利用します。

特徴量 概要
Sales calls 訪問販売の回数
Interactions 交流回数
Economy 居住地域の経済状況
Last upgrade 最終更新日
Discount 割引額
Monthly usage 月の利用度合い
Ad spend 広告費用
Bugs reported バグの報告回数

一方で、先ほどPythonで生成したデータのうち、下記の(交絡)変数は未観測とします。

未観測の交絡 概要
Product need プロダクトのニーズ
Bugs faced バグの遭遇回数

データの生成過程から、各変数の関係性は下図の通りになっています。破線部は未観測である変数を表しており、グレーで塗られている変数(Did renew)は目的変数を表しています。
スクリーンショット 2023-02-25 2.30.22.png

今回はデータの生成過程が分かっているため簡単に変数間の関係性を記述することができますが、現実世界における分析では変数間の関係性を記述することは容易では有りません。ドメイン知識や因果探索の手法に基づいて、変数間の関係性を考えていく必要があります。

ありがちなミスリーディング

XGBoostでサブスク購入を更新するかどうかを予測するモデルを作成し、SHAPで予測モデルにおける各特徴量の重要度を表す棒グラフをプロットします。

# 予測モデルを学習
model = fit_xgboost(X, y)

# SHAP
explainer = shap.Explainer(model)
shap_values = explainer(X)

clust = shap.utils.hclust(X, y, linkage="complete")

# 予測モデルにおける特徴量の重要度をプロット
shap.plots.bar(shap_values, clustering=clust, clustering_cutoff=1)

(出力結果)
download.png

棒グラフの長さがを予測モデルへの影響度(絶対値)を表しており、影響度の大きい特徴量Top3は大きい方から順番に

  • Discount: 割引額
  • Ad spend: 広告費用
  • Bugs reported: バグ報告回数

であることが分かります。この結果から

  • 割引額が大きいほど、更新されやすい
  • 広告費用をかけるほど、更新されやすい
  • バグ報告回数が多いほど、更新されにくい

と解釈したら、一見それっぽいように思えますよね?

しかし、もう少し深掘りしてみると、直感に反する傾向が見て取れます。それを確かめるべく、SHAPで散布図を描画してみます。

# 散布図を描画
shap.plots.scatter(shap_values)

(出力結果)
download-1.png

SHAPの散布図は「特徴量の値」が「モデルの予測(更新する確率)」に対してどのように変化を与えるかを表しています。例えば、青いプロットが右肩上がりであれば「その特徴量が大きいほど、更新する確率も大きい」と解釈できるのです。

先ほど棒グラフを描画した際に、モデル内での重要度が高いと出力された特徴量「Discount」の散布図を見てみるとどうでしょう。

download-2.png

「Discount(割引額)」の散布図は、右肩下がりになっています。これをそのまま解釈すると「割引すればするほど、更新する確率は低くなる」となります。これは、とても直感に反するかと思われます。

なぜ、このような直感と異なる結果となってしまったのでしょうか?それは、「Discount」と「Did renew」の両方に影響を与える交絡が存在しているからです。交絡を確認するために変数同士の関係性は再掲します。

スクリーンショット 2023-02-25 2.30.22.png

グラフを見てみると、「Discount」と「Did renew」の両方に影響を与える「Product need」が交絡となっていることが分かります。

「Discount(割引額)」の散布図が右肩下がりとなっているのは「割引すればするほど、更新する確率は低くなる」のではなく

  • 「Product need」が低いユーザーは(そもそも)更新する確率が低い
  • 「Product need」が低いユーザーに対して(更新を促すために)割引額を大きくする

ため、結果として「割引すればするほど、更新する確率は低くなる」ように見えていると考えられます。

これは、あくまでも「交絡の存在によって生じた(擬似的な)相関関係であり、因果関係ではない」ということに注意してください。相関関係と因果関係の違いについては、下記の記事を参照してください。

予測タスクと因果タスク

今までの説明を見ていると「SHAPを用いて求められる値(以下、SHAP値)は無意味なのではないか」と思われるかもしれませんが、そういうわけではありません。目的によります。

例えば、「更新するかどうかを予測する精度を高めて将来の収益を推定し、財務計画の意思決定に繋げたい」というような目的であれば、SHAP値が高い特徴量を知ることで予測精度の向上に役立てることができます。この例のようなタスクを予測タスクと呼びます。

これに対して、「できるだけ多くのユーザーの更新確率を上げるようなアクションに繋がる意思決定を考えたい」というような目的であれば、SHAP値ではなく、ある特徴量$X$を操作したときに目的変数(更新する確率)はどれほど変化するのかを知る必要があります。この例のようなタスクを因果タスクと呼びます。

SHAPで求められるモデル内における特徴量の重要度(モデルへの影響度)をSHAP値と呼称しているのは、SHAP値とShapley valueが異なる指標であるからです。

因果効果の推定にチャレンジ

まず、SHAP値に加えて真の因果効果もプロットします。

# SHAP値と真の因果効果をプロット
def marginal_effects(generative_model, num_samples=100, columns=None, max_points=20, logit=True, seed=0):
    """ Helper function to compute the true marginal causal effects.
    """
    X = generative_model(num_samples)
    if columns is None:
        columns = X.columns
    ys = [[] for _ in columns]
    xs = [X[c].values for c in columns]
    xs = np.sort(xs, axis=1)
    xs = [xs[i] for i in range(len(xs))]
    for i,c in enumerate(columns):
        xs[i] = np.unique([np.nanpercentile(xs[i], v, interpolation='nearest') for v in np.linspace(0, 100, max_points)])
        for x in xs[i]:
            Xnew = generative_model(num_samples, fixed={c: x}, seed=seed)
            val = Xnew["Did renew"].mean()
            if logit:
                val = scipy.special.logit(val)
            ys[i].append(val)
        ys[i] = np.array(ys[i])
    ys = [ys[i] - ys[i].mean() for i in range(len(ys))]
    return list(zip(xs, ys))

shap.plots.scatter(shap_values, ylabel="SHAP value\n(higher means more likely to renew)", overlay={
    "True causal effects": marginal_effects(generator, 10000, X.columns)
})

(出力結果)
download-4.png

青のプロットがSHAP値、黒の実線が真の因果効果を表しており、やはりいくつかの特徴量はSHAP値と真の因果効果が大きく異なっていることが分かります。

以下では、予測モデルや因果推論の手法を用いて因果効果を推定できるケース(あるいはできないケース)について記述しています。

予測モデルで因果関係を説明できるケース

「Economy(居住地域の経済状況)」という特徴量に注目してみます。

スクリーンショット 2023-02-25 18.00.09.png

因果グラフを見ると、「Economy」(青塗り部分)は「Did renew」とだけ繋がっており、モデル内の他の特徴量や未観測の変数と相関していません。このとき、「Economy」という特徴量は強い無視可能性があり、予測モデル(SHAP値)で因果関係を説明することができます。

「Economy」のSHAP値と真の因果効果のプロットを確認してみても
download-5.png

SHAP値のプロットが真の因果効果を概ね捉えていることが分かります。

今回はデータ生成過程が分かっているため、因果グラフを"正確に"記述することができ、予測モデル(SHAP値)で因果関係を説明することができます。しかし、モデル内の特徴量だけでなく、未観測な変数まで含めて因果グラフを記述できるというケースというのはあまり現実的ではありません。(これは因果推論タスク全般で言えることですが)

観察データから因果関係を説明できるケース

次は、予測モデルでは因果関係は説明できないものの、観察データから因果推論を行うことで因果関係を説明できるケースです。今回のデータでは2つのユースケースがあります。

交絡が観察されている場合

1つ目のユースケースは、交絡が観察されているケースです。「Ad spend(広告費用)」という特徴量に注目します。

スクリーンショット 2023-02-25 18.40.25.png

「Ad spend」(青塗り部分)は、観測可能な「Last upgrade」と「Monthly usage」(オレンジ塗り部分)という特徴量が交絡となっています。

「Ad spend」のSHAP値と真の因果効果のプロットを確認してみます。
download-6.png

真の因果効果はほとんど0にもかかわらず、SHAP値は右肩上がりとなっていることが分かります。

これは予測モデルが、「Last upgrade」や「Monthly usage」が予測に与える影響まで、「Ad spend」が予測に与える影響として捉えていることによって生じていると考えられます。

このような場合では、DML(Double/debiased Machine Learing)などの因果推論手法を用いて、バイアスを取り除いて因果効果を推定することができます。

# 必要なライブラリをインポート
from econml.dml import LinearDML
from sklearn.base import BaseEstimator, clone
import matplotlib.pyplot as plt

# DMLを実行し、因果効果の実値と推定値をプロットする関数を作成
class RegressionWrapper(BaseEstimator):
    """ Turns a classifier into a 'regressor'.

    We use the regression formulation of double ML, so we need to approximate the classifer
    as a regression model. This treats the probabilities as just quantitative value targets
    for least squares regression, but it turns out to be a reasonable approximation.
    """
    def __init__(self, clf):
        self.clf = clf

    def fit(self, X, y, **kwargs):
        self.clf_ = clone(self.clf)
        self.clf_.fit(X, y, **kwargs)
        return self

    def predict(self, X):
        return self.clf_.predict_proba(X)[:, 1]

def double_ml(y, causal_feature, control_features):
    """ Use doubleML from econML to estimate the slope of the causal effect of a feature.
    """
    xgb_model = xgboost.XGBClassifier(objective="binary:logistic", random_state=42)
    est = LinearDML(model_y=RegressionWrapper(xgb_model))
    est.fit(y, causal_feature, W=control_features)
    return est.effect_inference()

def plot_effect(effect, xs, true_ys, causal_feature, ylim=None):
    """ Plot a double ML effect estimate from econML as a line.

    Note that the effect estimate from double ML is an average effect *slope* not a full
    function. So we arbitrarily draw the slope of the line as passing through the origin.
    """
    plt.figure(figsize=(5, 3))

    pred_xs = [xs.min(), xs.max()]
    mid = (xs.min() + xs.max())/2
    pred_ys = [effect.pred[0]*(xs.min() - mid), effect.pred[0]*(xs.max() - mid)]

    plt.plot(xs, true_ys - true_ys[0], label='True causal effect', color="black", linewidth=3)
    point_pred = effect.point_estimate * pred_xs
    pred_stderr = effect.stderr * np.abs(pred_xs)
    plt.plot(pred_xs, point_pred - point_pred[0], label='Double ML slope', color=shap.plots.colors.blue_rgb, linewidth=3)
    plt.fill_between(pred_xs, point_pred - point_pred[0] - 3.291 * pred_stderr,
                     point_pred - point_pred[0] + 3.291 * pred_stderr, alpha=.2, color=shap.plots.colors.blue_rgb)
    plt.legend()
    plt.xlabel(causal_feature, fontsize=13)
    plt.ylabel("Zero centered effect")
    if ylim is not None:
        plt.ylim(*ylim)
    plt.gca().xaxis.set_ticks_position('bottom')
    plt.gca().yaxis.set_ticks_position('left')
    plt.gca().spines['right'].set_visible(False)
    plt.gca().spines['top'].set_visible(False)
    plt.show()

# Ad spendの因果効果を推定
causal_feature = "Ad spend"
control_features = [
    "Sales calls", "Interactions", "Economy", "Last upgrade", "Discount",
    "Monthly usage", "Bugs reported"
]
effect = double_ml(y, X[causal_feature], X.loc[:,control_features])

# 真の因果効果とDMLによる因果効果の推定値をプロット
xs, true_ys = marginal_effects(generator, 10000, X[["Ad spend"]], logit=False)[0]
plot_effect(effect, xs, true_ys, causal_feature, ylim=(-0.2, 0.2))

(出力結果)
download-7.png

黒の実線が真の因果効果、青の実線が推定された因果効果(青塗り部分は99%信頼区間)を表しており、よく推定できているようです。

DML(Double Machine Learning)の詳細につきましては、下記の記事をご参照ください。

交絡は存在せず冗長性がある場合

2つ目のユースケースは、交絡は存在しないが冗長性が存在するケースです。「Sales calls(訪問販売の回数)」という特徴量に注目します。

スクリーンショット 2023-02-25 19.46.16.png

「Sales calls」(青塗り部分)は関連する特徴量はありますが、「Sales calls」に影響を与え、かつ、「Did renew」にも影響を与える特徴量は存在していないため、「Sales calls」には交絡は存在していません。

「Sales calls」のSHAP値と真の因果効果のプロットを確認してみます。
download-8.png

真の因果効果は右肩上がりですが、SHAP値はほぼ横ばいとなっています。

これは、モデルに冗長性があり、「Sales calls」がモデルに与える影響度が「Interactions」(オレンジ塗り部分)という特徴量がモデルに与える影響度に分散されているためだと思われます。

このようなケースでは、「Interactions」をモデルから取り除いて、「Sales calls」の因果効果を推定する必要があります。

# Interactionsを取り除いて因果効果を推定
causal_feature = "Sales calls"
control_features = [
    "Economy", "Last upgrade", "Discount", 
    "Monthly usage", "Ad spend", "Bugs reported"
]
effect = double_ml(y, X[causal_feature], X.loc[:,control_features])

# 真の因果効果とDMLによる因果効果の推定値をプロット
xs, true_ys = marginal_effects(generator, 10000, X[[causal_feature]], logit=False)[0]
plot_effect(effect, xs, true_ys, causal_feature, ylim=(-0.2, 0.2))

(出力結果)
download-9.png

黒の実線が真の因果効果、青の実線が推定された因果効果(青塗り部分は99%信頼区間)を表しています。よく推定できているとは言えませんが、少しは右肩上がりの傾向を捉えられているようです。

念のため注釈を入れさせていただきますが、こちらはかなり恣意的な解釈です。実際にこの結果を受けた際は、このように解釈するのはよくありません。笑

少しは右肩上がりの傾向を捉えられているようです。

予測モデルや観察データからは因果関係を説明できないケース

最後に「Discount(割引額)」という変数に注目します。

スクリーンショット 2023-02-25 20.16.38.png

「Discount」は、未観測である「Product need」が交絡となっています。

「Discount」のSHAP値と真の因果効果のプロットを確認してみます。
download-10.png

真の因果効果はわずかに右肩上がりですが、SHAP値は大きく右肩下がりとなっています。これは「Product need」の交絡の影響が「Discount」に分散されているためです。

このような場合、観察データから因果効果を推定するのは困難なケースがほとんどです。因果効果を推定したい場合には実験(RCT)などを行う必要があります。

今回のデータでは利用できませんが、操作変数法(IV法)や差分の差分法(DID)、回帰不連続デザイン(RDD)などの手法を用いて、(実験を行わずとも)観察データから因果効果を推定できるケースもあります。

おわりに

最後まで読んでいただきありがとうございました。

zennにて「Python×データ分析」をメインテーマに記事を執筆しているので、ご一読いただけますと幸いです。

また、過去にLTや勉強会で発表した資料が下記リンクにてまとめてありますので、こちらもぜひご一読くださいませ。

参考文献

Web上の資料は2023年2月25日時点のものです。

191
164
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
191
164