因果推論もとい因果関係を考察する上で必要になるのが「そうじゃなかった世界」というものです。
もちろん機械学習モデルにランダムな値を入れる事で実現できる可能性もありますが、例えば変数間の関係性(多重共線性など)やその他色々あって簡単に代入できるものではありません。
そこで、反実仮想のdice_ml(DiCE)を使います。
使用するデータはCox比例ハザードモデルのライブラリにある追跡データのrossiです。
各変数については以下のようになっています。
参考 Pythonで生存時間解析 〜Cox比例ハザードモデルの解説と実装〜
変数名 | 内容 |
---|---|
week | 釈放後の最初の逮捕までの週数、または打ち切り |
arrest | 逮捕の有無 (1:有) |
fin | 財政援助の有無 (1:有) |
age | 釈放時の年齢 |
race | 黒人かどうか (1:黒人) |
wexp | 労働経験の有無 (1:有) |
mar | 釈放時の婚姻状況 (1:結婚) |
paro | 仮釈放かどうか (1:仮釈放) |
prio | 過去の受刑回数 |
では、ここでは逮捕が逆転した場合どのようになるかデータフレームを実際に作ってみようと思います。
精度の検証
精度を確認してから訓練データで逮捕の逆転を見てみます。
from sklearn.model_selection import train_test_split as tts
from sklearn.metrics import classification_report
from sklearn.ensemble import GradientBoostingClassifier as GBC
import pandas as pd
import numpy as np
import dice_ml
df = pd.read_csv("rossi.csv")
y = df["arrest"]
x = df.drop(["arrest"], axis=1)
x_train, x_test, y_train, y_test = tts(x, y, test_size=0.2, random_state=1)
model = GBC()
model.fit(x_train, y_train)
y_pred = model.predict(x_test)
print(classification_report(y_test, y_pred))
precision recall f1-score support
0 1.00 0.99 0.99 69
1 0.95 1.00 0.97 18
accuracy 0.99 87
macro avg 0.97 0.99 0.98 87
weighted avg 0.99 0.99 0.99 87
逮捕の逆転のデータセット
精度も好調なのでここからは訓練用データで逮捕が逆転した場合どうなるのか見ていきます。
d = dice_ml.Data(dataframe = pd.concat([x_test, y_test], axis=1),
continuous_features = ["week", "age", "prio"],
outcome_name = "arrest"
)
m = dice_ml.Model(model=model, backend="sklearn")
exp = dice_ml.Dice(d, m)
pre_counter = x_train.iloc[0:, :]
dice_exp = exp.generate_counterfactuals(pre_counter, total_CFs=1, desired_class = "opposite")
df_dice = dice_exp.cf_examples_list[0].final_cfs_df
for i in range(1, len(dice_exp.cf_examples_list)):
df_dice = pd.concat([df_dice, dice_exp.cf_examples_list[i].final_cfs_df])
df_dice
これらが推論の材料になります。
訓練データと仮想データの比較
訓練データと仮想データで効果測定としてprio(受刑回数)にどのように違いがあるかを見てみます。
train = pd.concat([x_train, y_train], axis=1)
for col in df_dice.columns:
df_dice[col] = df_dice[col].astype(int)
効果測定の関数を作ります。
import scipy.stats as stats
import matplotlib.pyplot as plt
def effecttest2(df, columns, y_name, auto=False):
ave = []
dtr = []
pos = []
lab = []
x = 0
for col in columns:
values = list(set(df[col].values))
tmp_ave = []
tmp_dtr = []
tmp_pos = []
tmp_lab = []
for val in values:
df_tmp = df[df[col]==val]
tmp_ave.append(df_tmp[y_name].mean())
tmp_dtr.append(df_tmp[y_name])
tmp_pos.append(x)
tmp_lab.append(col+"_"+str(val))
x = x + 1
ave.append(tmp_ave)
dtr.append(tmp_dtr)
pos.append(tmp_pos)
lab.append(tmp_lab)
if auto:
for i in range(len(dtr)):
for j in range(len(dtr[i])):
for k in range(j, len(dtr[i])):
if j != k:
f, p = stats.bartlett(dtr[i][0], dtr[i][1])
if (2 * p) <= 0.05:
t, p = stats.ttest_ind(dtr[i][j], dtr[i][k], equal_var=False)
else:
t, p = stats.ttest_ind(dtr[i][j], dtr[i][k], equal_var=True)
print(lab[i][j], lab[i][k])
print("t = %f"%(t))
print("p = %f"%(p))
print("val = %f"%(ave[i][j]-ave[i][k]))
print()
for i in range(len(dtr)):
plt.boxplot(dtr[i], positions=pos[i], labels=lab[i])
plt.plot(pos[i], ave[i], marker="x")
plt.xticks(rotation=90)
plt.show()
return ave, dtr, pos , lab
ave, dtr, pos , lab = effecttest2(train, ["fin", "race", "wexp", "mar", "paro", "arrest"], "prio", auto=True)
fin_0 fin_1
t = 0.132881
p = 0.894366
val = 0.041916
race_0 race_1
t = 2.657921
p = 0.008231
val = 1.295492
wexp_0 wexp_1
t = 4.821392
p = 0.000003
val = 1.569231
mar_0 mar_1
t = 0.758046
p = 0.448944
val = 0.372951
paro_0 paro_1
t = 2.514523
p = 0.012674
val = 0.876872
arrest_0 arrest_1
t = -2.935298
p = 0.003962
val = -1.226029
ave, dtr, pos , lab = effecttest2(df_dice, ["fin", "race", "wexp", "mar", "paro", "arrest"], "prio", auto=True)
fin_0 fin_1
t = -0.148670
p = 0.881901
val = -0.051271
race_0 race_1
t = 1.089637
p = 0.276638
val = 0.488430
wexp_0 wexp_1
t = 3.468914
p = 0.000609
val = 1.228205
mar_0 mar_1
t = 0.566045
p = 0.571733
val = 0.266458
paro_0 paro_1
t = 2.428683
p = 0.015956
val = 0.911929
arrest_0 arrest_1
t = 3.446230
p = 0.000768
val = 1.531789
もちろん逮捕について逆にしているのでデータが異なるのは当たり前と考えてください。
(今回はたまたま平均値の差が似ていますが)。
なので、極端に陽性または陰性や介入または非介入や曝露または非曝露のデータが少ない時はこのように反実仮想を使ってデータを作り因果推論をしてみるといいかもしれません。
まとめ
これらの「あったかもしれない世界」が高精度で作れることで、因果推論の材料を作る事ができるかもしれません。
それでなくても因果関係の考察は難しいのでこういったライブラリはありがたい限りです。