ROC曲線のテンプレート
- 自分用に作成している勉強用・備忘録のノートです。
- 専門的な内容は有りません。
- 間違いなどが有りましたらコメントよろしくおねがいします。
目的:
- コードのテンプレート化
とりあえず実装(statsmodels)
ここでは、statsmodelsのライブラリを用いて実装。
まず必要なライブラリの読み込み
import numpy as np
import matplotlib.pyplot as plt
# ROC 曲線を描くためのライブラリ
from sklearn.metrics import roc_curve
# AUC を求めるためのライブラリ
from sklearn import metrics
データを準備
# 特別ソートなどは必要ない
# 正解ラベル
y_true = [0, 0, 0, 0, 1, 1, 1, 1]
# 予測スコア
y_score = [0.2, 0.3, 0.6, 0.8, 0.4, 0.5, 0.7, 0.9]
# matplotlib の roc_curve を用いるとプロット用の出力が得られる。
fpr, tpr, thresholds = roc_curve(y_true, y_score)
# FPR(偽陽性率)
print(fpr)
# TPR(真陽性率)
print(tpr)
# 閾値
print(thresholds)
AUC の計算
AUC(area under curve; ROC曲線下の面積)
# AUC の計算
auc = metrics.auc(fpr, tpr)
ROC 曲線をプロット
matplotlib でROC 曲線をプロット
fig = plt.figure()
ax = fig.add_subplot()
# FPR を横軸, TPRを縦軸に設定する。
ax.plot(fpr, tpr, marker='o', label='ROC curve (area = %.2f)'%auc)
ax.plot(np.linspace(1, 0, len(fpr)), np.linspace(1, 0, len(fpr)),
label='Random curve (area = %.2f)'%0.5, linestyle = '--',)
ax.plot((0,0,1),(0,1,1),
label='Ideal curve (area = %.2f)'%1.0, linestyle = '--',)
plt.title("ROC curve")
plt.xlabel('FPR: False positive rate')
plt.ylabel('TPR: True positive rate')
plt.grid()
plt.legend()
plt.show()
最終的にこのような画像がプロットされる
少しだけ知識の整理
混合行列の要素
- 真陰性(True Negative, TN):検査で陰性とされ、実際に陰性である場合
- 偽陰性(False Negative, FN):検査で陰性とされたものの、実際には陽性である場合
- 偽陽性(False Positive, FP):検査で陽性とされたものの、実際には陰性である場合
- 真陽性(True Positive, TP):検査で陽性とされ、実際に陽性である場合
適合率 (Precision)
$\text{Precision} = \displaystyle \frac{TP}{TP + FP} = \frac{\text{予測が陽性で実際に陽性}}{予測が陽性としたものすべて}$
使い所は誤って陽性と判断しては困る状況。例えばがんの検出など.
使い方(解釈)
ROC曲線のAUCが低いからと言って作成したモデルが有益ではないと決めるのは早計である。
例えば、閾値以上の中に陰性データが多いほど利益が高くなり、陽性データが閾値以上のグループに入ればその分大きく損害が発生するシステムを考える。ROC曲線の立ち上がりが良いほど陽性データを弾くための閾値を小さく設定しよりたくさんの陰性データを閾値以上のグループに仕分ける事ができる。逆に同じAUCでも立ち上がりが遅いROC曲線だとその分閾値を大きく設定する必要があるため、閾値以上に入る陰性データが少なくなり結果として利益を生み出すことができない。
このように問題となる背景によってはAUCの値だけではなくROC曲線を確認することが重要になることが分かる。
付録(PR曲線)
## PR曲線
precision, recall, thresholds = metrics.precision_recall_curve(y_true, y_score)
auc = metrics.auc(recall, precision)
print(auc)
plt.plot(recall, precision, label='PR curve (area = %.2f)'%auc)
plt.legend()
plt.title('PR curve')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.grid(True)
plt.show()