LoginSignup
3
3

ROC 曲線

Last updated at Posted at 2022-05-09

ROC曲線のテンプレート

  • 自分用に作成している勉強用・備忘録のノートです。
  • 専門的な内容は有りません。
  • 間違いなどが有りましたらコメントよろしくおねがいします。

目的:

  • コードのテンプレート化

とりあえず実装(statsmodels)

ここでは、statsmodelsのライブラリを用いて実装。

まず必要なライブラリの読み込み

import numpy as np 
import matplotlib.pyplot as plt

# ROC 曲線を描くためのライブラリ
from sklearn.metrics import roc_curve
# AUC を求めるためのライブラリ 
from sklearn import metrics

データを準備

# 特別ソートなどは必要ない
# 正解ラベル
y_true = [0, 0, 0, 0, 1, 1, 1, 1]
# 予測スコア
y_score = [0.2, 0.3, 0.6, 0.8, 0.4, 0.5, 0.7, 0.9]

# matplotlib の roc_curve を用いるとプロット用の出力が得られる。
fpr, tpr, thresholds = roc_curve(y_true, y_score)

# FPR(偽陽性率)
print(fpr)
# TPR(真陽性率)
print(tpr)
# 閾値
print(thresholds)

AUC の計算

AUC(area under curve; ROC曲線下の面積)

# AUC の計算
auc = metrics.auc(fpr, tpr)

ROC 曲線をプロット

matplotlib でROC 曲線をプロット

fig = plt.figure()
ax = fig.add_subplot()

# FPR を横軸, TPRを縦軸に設定する。

ax.plot(fpr, tpr, marker='o',  label='ROC curve (area = %.2f)'%auc)

ax.plot(np.linspace(1, 0, len(fpr)), np.linspace(1, 0, len(fpr)), 
        label='Random curve (area = %.2f)'%0.5, linestyle = '--',)

ax.plot((0,0,1),(0,1,1), 
        label='Ideal curve (area = %.2f)'%1.0, linestyle = '--',)

plt.title("ROC curve")
plt.xlabel('FPR: False positive rate')
plt.ylabel('TPR: True positive rate')
plt.grid()
plt.legend()

plt.show()

最終的にこのような画像がプロットされる

roc.png

少しだけ知識の整理

混合行列の要素

  • 真陰性(True Negative, TN):検査で陰性とされ、実際に陰性である場合
  • 偽陰性(False Negative, FN):検査で陰性とされたものの、実際には陽性である場合
  • 偽陽性(False Positive, FP):検査で陽性とされたものの、実際には陰性である場合
  • 真陽性(True Positive, TP):検査で陽性とされ、実際に陽性である場合

適合率 (Precision)

$\text{Precision} = \displaystyle \frac{TP}{TP + FP} = \frac{\text{予測が陽性で実際に陽性}}{予測が陽性としたものすべて}$
使い所は誤って陽性と判断しては困る状況。例えばがんの検出など.

使い方(解釈)

ROC曲線のAUCが低いからと言って作成したモデルが有益ではないと決めるのは早計である。
例えば、閾値以上の中に陰性データが多いほど利益が高くなり、陽性データが閾値以上のグループに入ればその分大きく損害が発生するシステムを考える。ROC曲線の立ち上がりが良いほど陽性データを弾くための閾値を小さく設定しよりたくさんの陰性データを閾値以上のグループに仕分ける事ができる。逆に同じAUCでも立ち上がりが遅いROC曲線だとその分閾値を大きく設定する必要があるため、閾値以上に入る陰性データが少なくなり結果として利益を生み出すことができない。

このように問題となる背景によってはAUCの値だけではなくROC曲線を確認することが重要になることが分かる。

付録(PR曲線)

## PR曲線
precision, recall, thresholds = metrics.precision_recall_curve(y_true, y_score)

auc = metrics.auc(recall, precision)
print(auc)

plt.plot(recall, precision, label='PR curve (area = %.2f)'%auc)
plt.legend()
plt.title('PR curve')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.grid(True)
plt.show()

roc.png

参考文献

3
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
3