1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

二値分類そのままやるべからず(関数作った)

Last updated at Posted at 2024-05-16

二値分類は基本的に0.5を境に0.5以上だと陽性、未満だと陰性とみなすことが多いですが、必ずしもそれが正しいとは言い切れません。
というのも正解率がどれだけ高くても再現率と適合率のバランスが悪かったりするとダメですし、元のデータの偏りも考慮しないといけません。
(例えばですが陽性90個と陰性10個のデータがあって全部陽性と診断しても正解率単体では90%でも陰性に着目すると評価は悲惨(陰性反応的中率0%で陰性再現率0%)なんですよ)
そこで最適なしきい値を考察する関数を作ってみました。

関数

from sklearn.metrics import classification_report
import numpy as np
import matplotlib.pyplot as plt
def optimal_threshold(y, pred, n="0.0", p="1.0"):
    tsd = np.linspace(0.1, 0.9, 100)
    accs = []
    rec0 = []
    rec1 = []
    pcs0 = []
    pcs1 = []
    f1_0 = []
    f1_1 = []
    for i in range(len(tsd)):
        y_pred = np.where(pred >= tsd[i], 1, 0)
        rep = classification_report(y, y_pred, output_dict=True)
        accs.append(rep["accuracy"])
        rec0.append(rep[n]["recall"])
        rec1.append(rep[p]["recall"])
        pcs0.append(rep[n]["precision"])
        pcs1.append(rep[p]["precision"])
        f1_0.append(rep[n]["f1-score"])
        f1_1.append(rep[p]["f1-score"])
    plt.plot(tsd, accs, label="accuracy")
    plt.plot(tsd, rec0, label="recall 0")
    plt.plot(tsd, rec1, label="recall 1")
    plt.plot(tsd, pcs0, label="precision 0")
    plt.plot(tsd, pcs1, label="precision 1")
    plt.plot(tsd, f1_0, label="f1-score 0")
    plt.plot(tsd, f1_1, label="f1-score 1")
    plt.legend()
    plt.show()
    return tsd[np.argmax(accs)], tsd[np.argmax(rec0)], tsd[np.argmax(rec1)], tsd[np.argmax(pcs0)], tsd[np.argmax(pcs1)], tsd[np.argmax(f1_0)], tsd[np.argmax(f1_1)]

返り値の順は正解率・陰性の再現率・陽性の再現率・陰性の適合率・陽性の適合率・陰性のF1・陽性のF1になります。

使用例

import pandas as pd
import statsmodels.api as sm

df = pd.read_csv("breast_cancer.csv")

y = df["y"]
x = df.drop("y", axis=1)
for col in x.columns:
    x[col] = (x[col] - x[col].mean()) / x[col].std()

model = sm.Logit(y, x).fit_regularized()
otp = optimal_threshold(y, model.predict(x), n="0.0", p="1.0")
otp

Untitled.png

(0.398989898989899,
 0.5767676767676768,
 0.1,
 0.1,
 0.5767676767676768,
 0.398989898989899,
 0.398989898989899)

なので正解率とF1からおおよそ0.4くらいをしきい値にするとちょうどよいかもしれません。
では実際に混合行列を見てみましょう

from sklearn.metrics import confusion_matrix

pred = model.predict(x)
y_pred1 = np.where(pred >= 0.5, 1, 0)
y_pred2 = np.where(pred >= 0.398989898989899, 1, 0)

しきい値が0.5(デフォルト値)

pd.DataFrame(confusion_matrix(y, y_pred1))

image.png

しきい値を最適値にした場合

pd.DataFrame(confusion_matrix(y, y_pred2))

image.png

陽性を正しく陽性と完璧に判断でき、再現率が向上したことが分かります。

まとめ

0.5が最適とは限らない。

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?