Writing my own code for the ROC curve and AUC

Posted at 2019-10-17

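The ROC curve and AUC are standard tools for evaluating model performance, for example in machine learning. However, ready-made functions are inconvenient when you want to choose the cutoff points yourself, so I wrote my own implementation.
(If I have misunderstood anything, please let me know!)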
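To make the idea of a cutoff concrete, here is a minimal pure-Python sketch of what happens at a single cutoff: a score at or above the cutoff is predicted positive, and the four confusion-matrix cells are counted the same way as the PP/PN/NP/NN counters in ROC.py. The helper names `confusion_at_cutoff` and `sensitivity_specificity` and the toy scores/labels are hypothetical, chosen only for illustration.

```python
def confusion_at_cutoff(scores, labels, cutoff):
    """Count (PP, PN, NP, NN) = (TP, FN, FP, TN) at one cutoff.

    A score >= cutoff is predicted positive (1), otherwise negative (0).
    labels are the true classes, 0 or 1.
    """
    PP = PN = NP = NN = 0
    for score, label in zip(scores, labels):
        pred = 1 if score >= cutoff else 0
        if label == 1 and pred == 1:
            PP += 1          # true positive
        elif label == 1:
            PN += 1          # false negative
        elif pred == 1:
            NP += 1          # false positive
        else:
            NN += 1          # true negative
    return PP, PN, NP, NN


def sensitivity_specificity(PP, PN, NP, NN):
    sen = PP / (PP + PN)  # true positive rate
    spe = NN / (NP + NN)  # true negative rate
    return sen, spe


# Hypothetical toy data: model scores and true labels.
scores = [0.9, 0.8, 0.7, 0.6, 0.4, 0.3, 0.2, 0.1]
labels = [1, 1, 0, 1, 0, 1, 0, 0]

counts = confusion_at_cutoff(scores, labels, 0.5)
print(counts)                            # (3, 1, 1, 3)
print(sensitivity_specificity(*counts))  # (0.75, 0.75)
```

Sweeping the cutoff from 0 to 1 and recording (1 - specificity, sensitivity) at each step traces out the ROC curve.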

ROC curve after building the model

ROC.py

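```python
# Code for CutOff_and_ROCcurve.
# Assumes `model`, `test_arr`, `test_answer_arr` (true labels, 0/1) and
# `ROCcurve_Save` (output file path) are defined earlier.
import numpy as np
import matplotlib.pyplot as plt

# model.predict is assumed to return positive-class probabilities
# (e.g. a Keras model with a sigmoid output); ravel turns (n, 1) into (n,).
Predict_arr = np.ravel(model.predict(test_arr))
answer_arr = np.ravel(test_answer_arr)

cutoff_list = np.arange(0, 1.001, 0.001)

# For each cutoff, count the confusion-matrix cells:
# PP = true positive, PN = false negative, NP = false positive, NN = true negative
cut_sen_spe = []
for cutoff in cutoff_list:
    pred_answer_arr = (Predict_arr >= cutoff).astype(int)
    PP = np.sum((answer_arr == 1) & (pred_answer_arr == 1))
    PN = np.sum((answer_arr == 1) & (pred_answer_arr == 0))
    NP = np.sum((answer_arr == 0) & (pred_answer_arr == 1))
    NN = np.sum((answer_arr == 0) & (pred_answer_arr == 0))
    cut_sen_spe.append([cutoff, PP, PN, NP, NN])

# Sensitivity = TP / (TP + FN), specificity = TN / (FP + TN)
total_list = []
for cutoff, PP, PN, NP, NN in cut_sen_spe:
    sen = PP / (PP + PN)
    spe = NN / (NP + NN)
    total_list.append([sen, spe])

total_arr = np.array(total_list)

sen_arr = total_arr[:, 0]
spe_arr = total_arr[:, 1]


# Visualization & Save ROCcurve
f = plt.figure(figsize=(10, 10))
plt.plot(1 - spe_arr, sen_arr, c="r", linewidth=3.0)
plt.plot([0.0, 1.0], [0.0, 1.0], linestyle="dashed")
plt.xlabel("1 - specificity (false positive rate)")
plt.ylabel("sensitivity (true positive rate)")
plt.xticks(np.arange(0, 1.01, 0.1))
plt.yticks(np.arange(0, 1.01, 0.1))
plt.xlim(-0.03, 1.03)
plt.ylim(-0.03, 1.03)
plt.savefig(ROCcurve_Save)
plt.show()

# AUC (trapezoidal rule)
# Both rates can only fall as the cutoff rises, so reversing the arrays
# orders the points by increasing false positive rate while keeping each
# (FPR, TPR) pair together. Copies leave sen_arr / spe_arr unmodified.
# Note: if some scores are exactly 1.0, even the highest cutoff counts
# them as positive, so the curve does not reach (0, 0).
target_falPos_arr = (1 - spe_arr)[::-1].copy()
target_sen_arr = sen_arr[::-1].copy()

AUC_segIMG = 0
for ind in range(len(target_falPos_arr) - 1):
    X_p = target_falPos_arr[ind]
    nex_X = target_falPos_arr[ind + 1]
    Y_p = target_sen_arr[ind]
    nex_Y = target_sen_arr[ind + 1]
    # rectangle under Y_p plus the triangle up to nex_Y (= one trapezoid)
    area = (nex_X - X_p) * Y_p + (nex_X - X_p) * (nex_Y - Y_p) / 2
    AUC_segIMG += area

print(AUC_segIMG)
```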
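As a cross-check for the threshold-grid AUC, here is a minimal pure-Python sketch (standard library only, toy data made up for illustration). It builds ROC points over a cutoff grid from 0.000 to 1.000, integrates them with the trapezoidal rule, and compares the result with the rank-based AUC, i.e. the probability that a randomly chosen positive scores higher than a randomly chosen negative, with ties counted as 1/2. The helper names `roc_points`, `trapezoid_auc` and `rank_auc` are hypothetical. On this data the two agree; with real scores, whenever two distinct scores fall between the same pair of adjacent cutoffs, the grid version can only approximate the exact value.

```python
def roc_points(scores, labels, cutoffs):
    """Return one (FPR, TPR) pair per cutoff, in the order given."""
    n_pos = sum(1 for y in labels if y == 1)
    n_neg = len(labels) - n_pos
    points = []
    for c in cutoffs:
        tp = sum(1 for s, y in zip(scores, labels) if s >= c and y == 1)
        fp = sum(1 for s, y in zip(scores, labels) if s >= c and y == 0)
        points.append((fp / n_neg, tp / n_pos))
    return points


def trapezoid_auc(points):
    """Trapezoidal area under a curve given as (x, y) points."""
    pts = sorted(points)  # sort by FPR, then TPR; each pair stays together
    area = 0.0
    for (x0, y0), (x1, y1) in zip(pts, pts[1:]):
        area += (x1 - x0) * (y0 + y1) / 2
    return area


def rank_auc(scores, labels):
    """Exact AUC: P(positive score > negative score), ties count 1/2."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# Hypothetical toy data: model scores and true labels.
scores = [0.9, 0.8, 0.7, 0.6, 0.4, 0.3, 0.2, 0.1]
labels = [1, 1, 0, 1, 0, 1, 0, 0]
cutoffs = [i / 1000 for i in range(1001)]  # 0.000, 0.001, ..., 1.000

points = roc_points(scores, labels, cutoffs)
print(trapezoid_auc(points))     # 0.8125
print(rank_auc(scores, labels))  # 0.8125
```

Sorting the (FPR, TPR) pairs as tuples keeps each point's coordinates together, which is a safer way to order the curve than sorting the two arrays independently.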