1
0

反実仮想で、あったかもしれない世界を

Posted at

因果推論もとい因果関係を考察する上で必要になるのが「そうじゃなかった世界」というものです。
もちろん機械学習モデルにランダムな値を入れる事で実現できる可能性もありますが、例えば変数間の関係性(多重共線性など)やその他色々あって簡単に代入できるものではありません。
そこで、反実仮想のdice_ml(DiCE)を使います。
使用するデータはCox比例ハザードモデルのライブラリにある追跡データのrossiです。
各変数については以下のようになっています。
参考 Pythonで生存時間解析 〜Cox比例ハザードモデルの解説と実装〜

変数名 内容
week 釈放後の最初の逮捕までの週数、または打ち切り
arrest 逮捕の有無 (1:有)
fin 財政援助の有無 (1:有)
age 釈放時の年齢
race 黒人かどうか (1:黒人)
wexp 労働経験の有無 (1:有)
mar 釈放時の婚姻状況 (1:結婚)
paro 仮釈放かどうか (1:仮釈放)
prio 過去の受刑回数

では、ここでは逮捕が逆転した場合どのようになるかデータフレームを実際に作ってみようと思います。

精度の検証

精度を確認してから訓練データで逮捕の逆転を見てみます。

from sklearn.model_selection import train_test_split as tts
from sklearn.metrics import classification_report
from sklearn.ensemble import GradientBoostingClassifier as GBC
import pandas as pd
import numpy as np
import dice_ml

df = pd.read_csv("rossi.csv")
y = df["arrest"]
x = df.drop(["arrest"], axis=1)

x_train, x_test, y_train, y_test = tts(x, y, test_size=0.2, random_state=1)
model = GBC()
model.fit(x_train, y_train)
y_pred = model.predict(x_test)
print(classification_report(y_test, y_pred))
              precision    recall  f1-score   support

           0       1.00      0.99      0.99        69
           1       0.95      1.00      0.97        18

    accuracy                           0.99        87
   macro avg       0.97      0.99      0.98        87
weighted avg       0.99      0.99      0.99        87

逮捕の逆転のデータセット

精度も好調なのでここからは訓練用データで逮捕が逆転した場合どうなるのか見ていきます。

d = dice_ml.Data(dataframe = pd.concat([x_test, y_test], axis=1),
                 continuous_features = ["week", "age", "prio"], 
                 outcome_name = "arrest"
                )
m = dice_ml.Model(model=model, backend="sklearn")
exp = dice_ml.Dice(d, m)
pre_counter = x_train.iloc[0:, :] 
dice_exp = exp.generate_counterfactuals(pre_counter, total_CFs=1, desired_class = "opposite")
df_dice = dice_exp.cf_examples_list[0].final_cfs_df
for i in range(1, len(dice_exp.cf_examples_list)):
    df_dice = pd.concat([df_dice, dice_exp.cf_examples_list[i].final_cfs_df])
df_dice

image.png

これらが推論の材料になります。

訓練データと仮想データの比較

訓練データと仮想データで効果測定としてprio(受刑回数)にどのように違いがあるかを見てみます。

train = pd.concat([x_train, y_train], axis=1)
for col in df_dice.columns:
    df_dice[col] = df_dice[col].astype(int)

効果測定の関数を作ります。

import scipy.stats as stats
import matplotlib.pyplot as plt
def effecttest2(df, columns, y_name, auto=False):
    ave = []
    dtr = []
    pos = []
    lab = []
    x = 0
    for col in columns:
        values = list(set(df[col].values))
        tmp_ave = []
        tmp_dtr = []
        tmp_pos = []
        tmp_lab = []
        for val in values:
            df_tmp = df[df[col]==val]
            tmp_ave.append(df_tmp[y_name].mean())
            tmp_dtr.append(df_tmp[y_name])
            tmp_pos.append(x)
            tmp_lab.append(col+"_"+str(val))
            x = x + 1
        ave.append(tmp_ave)
        dtr.append(tmp_dtr)
        pos.append(tmp_pos)
        lab.append(tmp_lab)
    if auto:
        for i in range(len(dtr)):
            for j in range(len(dtr[i])):
                for k in range(j, len(dtr[i])):
                    if j != k:
                        f, p = stats.bartlett(dtr[i][0], dtr[i][1])
                        if (2 * p) <= 0.05:
                            t, p = stats.ttest_ind(dtr[i][j], dtr[i][k], equal_var=False)
                        else:
                            t, p = stats.ttest_ind(dtr[i][j], dtr[i][k], equal_var=True)
                        print(lab[i][j], lab[i][k])
                        print("t   = %f"%(t))
                        print("p   = %f"%(p))
                        print("val = %f"%(ave[i][j]-ave[i][k]))
                        print()
        for i in range(len(dtr)):
            plt.boxplot(dtr[i], positions=pos[i], labels=lab[i])
            plt.plot(pos[i], ave[i], marker="x")
        plt.xticks(rotation=90)
        plt.show()
    return ave, dtr, pos , lab
ave, dtr, pos , lab = effecttest2(train, ["fin", "race", "wexp", "mar", "paro", "arrest"], "prio", auto=True)
fin_0 fin_1
t   = 0.132881
p   = 0.894366
val = 0.041916

race_0 race_1
t   = 2.657921
p   = 0.008231
val = 1.295492

wexp_0 wexp_1
t   = 4.821392
p   = 0.000003
val = 1.569231

mar_0 mar_1
t   = 0.758046
p   = 0.448944
val = 0.372951

paro_0 paro_1
t   = 2.514523
p   = 0.012674
val = 0.876872

arrest_0 arrest_1
t   = -2.935298
p   = 0.003962
val = -1.226029

Untitled.png

ave, dtr, pos , lab = effecttest2(df_dice, ["fin", "race", "wexp", "mar", "paro", "arrest"], "prio", auto=True)
fin_0 fin_1
t   = -0.148670
p   = 0.881901
val = -0.051271

race_0 race_1
t   = 1.089637
p   = 0.276638
val = 0.488430

wexp_0 wexp_1
t   = 3.468914
p   = 0.000609
val = 1.228205

mar_0 mar_1
t   = 0.566045
p   = 0.571733
val = 0.266458

paro_0 paro_1
t   = 2.428683
p   = 0.015956
val = 0.911929

arrest_0 arrest_1
t   = 3.446230
p   = 0.000768
val = 1.531789

Untitled.png
もちろん逮捕について逆にしているのでデータが異なるのは当たり前と考えてください。
(今回はたまたま平均値の差が似ていますが)。
なので、極端に陽性または陰性や介入または非介入や曝露または非曝露のデータが少ない時はこのように反実仮想を使ってデータを作り因果推論をしてみるといいかもしれません。

まとめ

これらの「あったかもしれない世界」が高精度で作れることで、因果推論の材料を作る事ができるかもしれません。
それでなくても因果関係の考察は難しいのでこういったライブラリはありがたい限りです。

参考文献

Pythonで生存時間解析 〜Cox比例ハザードモデルの解説と実装〜

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0