3
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

2項分類:しきい値を自動で計算する

Last updated at Posted at 2022-03-31

1.はじめに

OKデータとNGデータの数が不均衡な場合に、分類器が出力するスコアから、OKとNGを分けるしきい値を自動で探す方法を説明します。

2.やりたいこと

下記のようなOKデータとNGデータがあるとします。
この分類器を配布する際にはしきい値を決める必要がありますが、それを手作業ではなく、スマートな（自動的な）方法で行いたい、というのが目的です。
Figure_1.png

3.説明

3.1データの作成

まず、numpy.random.normal()で、正規分布に従うOK、NGデータのスコアを作成します。
ここでは、下記のような条件で作成します。

  1. OKデータの数:10000個
  2. NGデータの数: 1000個
  3. OKデータのスコアの平均:2
  4. NGデータのスコアの平均:4
  5. OKデータの標準偏差:0.5
  6. NGデータの標準偏差: 0.8
import numpy as np
import matplotlib.pyplot as plt
# plot_confusion_matrix was removed from sklearn.metrics in scikit-learn 1.2 (importing it
# raises ImportError); auc / confusion_matrix are not used, so import only what is needed.
from sklearn.metrics import roc_curve, roc_auc_score, precision_recall_curve

from make_random_data import data_array
from make_plot import drawplots

# variables

number_OK = 1e4  # number of OK samples (label 0); data_array converts it with int()
number_NG = 1e3  # number of NG samples (label 1)
OK_mean = 2
NG_mean = 4
OK_stdev = 0.5
NG_stdev = 0.8
# 1. Data preparation
data = data_array(n_OK=number_OK, n_NG=number_NG, mean_OK=OK_mean, mean_NG=NG_mean, std_OK=OK_stdev, std_NG=NG_stdev)
Score_data, Label_data, OK_score, NG_score = data.make()

描画のためのインスタンスを用意します。

# Build the plotting helper from the generated data, then show the raw
# OK/NG score distributions (no threshold line yet).
plots = drawplots(
    Score_data=Score_data,
    Label_data=Label_data,
    OK_score=OK_score,
    NG_score=NG_score,
)
plots.draw_histogram(threshold_opt=None)

3.2計算 (ROC曲線、AUCスコア、f1スコア)

ROC曲線とAUCスコアを計算します。

# 2. ROC curve: false-positive and true-positive rates for each candidate threshold.
fpr, tpr, threshold = roc_curve(y_true=Label_data, y_score=Score_data)

# 2.1 AUC: area under the ROC curve, computed from the same labels and scores.
roc_auc_value = roc_auc_score(y_true=Label_data, y_score=Score_data)

f1スコアを計算します。

# 3. F1 value at every threshold of the precision-recall curve.
# Arguments are passed positionally: the `probas_pred` keyword was renamed to
# `y_score` in newer scikit-learn releases, so positional calls work on all versions.
precision, recall, threshold_from_pr = precision_recall_curve(Label_data, Score_data)
# F1 = 2PR / (P + R); where P + R == 0 the F1 is set to 0 instead of dividing by zero.
a = 2 * precision * recall
b = precision + recall
f1 = np.divide(a, b, out=np.zeros_like(a), where=b != 0)

3.3最適な閾値の計算

f1スコアを最大にするインデックスを計算します。

# 4. Optimal value
# 4.1 Position of the largest F1 score; it selects the optimal threshold below.
idx_opt = f1.argmax()

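ここで計算が少しややこしくなります。precision_recall_curve と roc_curve はそれぞれ別のしきい値配列を返し、その長さも並び順も異なります。そこで、F1スコアが最大になるしきい値をPR曲線側の配列から取り出し、ROC曲線側のしきい値配列の中で、それに対応する位置（インデックス）を探します。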

# The tricky part: precision_recall_curve and roc_curve return different threshold
# arrays (different length and order), so the optimum must be mapped onto each one.
threshold_opt = threshold_from_pr[idx_opt]  # used for the confusion matrix / histogram
# roc_curve drops intermediate thresholds by default, so an exact `==` match can come
# back empty (no red point on the ROC plot). Take the closest ROC threshold instead.
idx_opt_from_pr = np.argmin(np.abs(threshold - threshold_opt))  # index into fpr / tpr (ROC curve)

4.結果

必要なグラフを描画します。

# Draw every diagnostic figure, then display them all at once.
plots.draw_f1_score(threshold_from_pr=threshold_from_pr, f1=f1)          # F1 vs threshold
plots.draw_precision_recall(precision=precision, recall=recall)          # precision-recall curve
plots.draw_ROC_curve(fpr=fpr, tpr=tpr, opt_idx=idx_opt_from_pr,
                     roc_auc_value=roc_auc_value)                        # ROC with optimal point
plots.draw_confusion_matrix(Score_data=Score_data, Label_data=Label_data,
                            threshold_opt=threshold_opt)                 # confusion matrix
plots.draw_histogram(threshold_opt=threshold_opt)                        # distributions + threshold
plt.show()

4.1 F1スコア with 最適閾値

F1スコアを最大にするしきい値を最適値とします。この場合、Thresholdは 3.245です（データを乱数で生成しているため、実行するたびに値は多少変わります）。
Figure_2.png

4.2 ROC曲線 with 最適閾値

最適なThresholdの部分を赤い点に表示します。
Figure_4.png

4.3 データ分布 with 最適閾値

Thresholdは 3.245です。
Figure_6.png

4.4 混同行列

Figure_5.png

4.5 Precision - Recall 曲線

Figure_3.png

5.まとめ

分類問題の課題である「最適なしきい値を探す」方法として、F1スコアが最大となるしきい値を自動で求める方法を説明しました。

6.全体プログラムコード

本体プログラム

import numpy as np
import matplotlib.pyplot as plt
# plot_confusion_matrix was removed from sklearn.metrics in scikit-learn 1.2 (importing it
# raises ImportError); auc / confusion_matrix are not used, so import only what is needed.
from sklearn.metrics import roc_curve, roc_auc_score, precision_recall_curve

from make_random_data import data_array
from make_plot import drawplots

# variables

number_OK = 1e4  # number of OK samples (label 0); data_array converts it with int()
number_NG = 1e3  # number of NG samples (label 1)
OK_mean = 2
NG_mean = 4
OK_stdev = 0.5
NG_stdev = 0.8

# 1. Data preparation
data = data_array(n_OK=number_OK, n_NG=number_NG, mean_OK=OK_mean, mean_NG=NG_mean, std_OK=OK_stdev, std_NG=NG_stdev)
Score_data, Label_data, OK_score, NG_score = data.make()

plots = drawplots(Score_data=Score_data, Label_data=Label_data, OK_score=OK_score, NG_score=NG_score)
plots.draw_histogram(threshold_opt=None)

# 2. ROC curve
fpr, tpr, threshold = roc_curve(y_true=Label_data, y_score=Score_data)

# 2.1 AUC score
roc_auc_value = roc_auc_score(y_true=Label_data, y_score=Score_data)

# 3. F1 value
# Positional arguments: the `probas_pred` keyword was renamed to `y_score` in newer
# scikit-learn releases, so positional calls work on all versions.
precision, recall, threshold_from_pr = precision_recall_curve(Label_data, Score_data)
a = 2 * precision * recall
b = precision + recall
f1 = np.divide(a, b, out=np.zeros_like(a), where=b != 0)  # F1 = 2PR/(P+R), 0 where P+R == 0

# 4. Optimal value
# 4.1 find the optimal threshold: the one that maximizes F1
idx_opt = np.argmax(f1)
threshold_opt = threshold_from_pr[idx_opt]  # used for the confusion matrix / histogram

# 4.2 locate the same threshold on the ROC curve. The two threshold arrays differ in
# length and order, and roc_curve drops intermediate thresholds by default, so an
# exact `==` match may be empty; use the closest ROC threshold instead.
idx_opt_from_pr = np.argmin(np.abs(threshold - threshold_opt))  # index into fpr / tpr (ROC curve)

# 5. draw F1 score
plots.draw_f1_score(threshold_from_pr=threshold_from_pr, f1=f1)

# 6. draw precision-recall curve
plots.draw_precision_recall(precision=precision, recall=recall)

# 7. plot ROC curve
plots.draw_ROC_curve(fpr=fpr, tpr=tpr, opt_idx=idx_opt_from_pr, roc_auc_value=roc_auc_value)

# 8. draw confusion matrix
plots.draw_confusion_matrix(Score_data=Score_data, Label_data=Label_data, threshold_opt=threshold_opt)

# 9. draw histogram with threshold
plots.draw_histogram(threshold_opt=threshold_opt)

plt.show()
make_random_data.py データ生成プログラム
import numpy as np
import matplotlib.pyplot as plt


class data_array():
    """Generate synthetic classifier scores for a binary OK/NG problem.

    OK scores are drawn from a normal distribution N(mean_OK, std_OK) and labelled 0;
    NG scores are drawn from N(mean_NG, std_NG) and labelled 1.

    Args:
        n_OK, n_NG: number of OK / NG samples (floats such as 1e4 are accepted and
            converted with int()).
        mean_OK, mean_NG: mean score of each class.
        std_OK, std_NG: standard deviation of each class.
        seed: optional seed for reproducible data. With ``None`` (default) the global
            ``np.random`` state is used, exactly as before, so an external
            ``np.random.seed(...)`` still controls the output.
    """

    def __init__(self, n_OK=1e4, n_NG=1e4, mean_OK=2.5, mean_NG=4, std_OK=0.5, std_NG=0.5, seed=None):
        self.n_OK = int(n_OK)
        self.n_NG = int(n_NG)
        self.mean_OK = mean_OK
        self.mean_NG = mean_NG
        self.std_OK = std_OK
        self.std_NG = std_NG
        self.seed = seed

        # Combined arrays, filled in by make().
        self.Score_data = None
        self.Label_data = None
        # Per-class arrays, also filled in by make(); initialised so they always exist.
        self.OK_score = None
        self.NG_score = None
        self.OK_label = None
        self.NG_label = None

    def make(self):
        """Generate scores and labels.

        Returns:
            (Score_data, Label_data, OK_score, NG_score). In Score_data and Label_data
            all OK samples come first, followed by all NG samples.
        """
        self._make_score()
        self._make_label()

        return self.Score_data, self.Label_data, self.OK_score, self.NG_score

    def _make_score(self):
        # Without a seed keep using the module-level np.random functions (original
        # behaviour); with a seed, a fresh Generator makes every make() call reproducible.
        rng = np.random if self.seed is None else np.random.default_rng(self.seed)
        self.OK_score = rng.normal(loc=self.mean_OK, scale=self.std_OK, size=self.n_OK)
        self.NG_score = rng.normal(loc=self.mean_NG, scale=self.std_NG, size=self.n_NG)
        self.Score_data = np.concatenate((self.OK_score, self.NG_score))

    def _make_label(self):
        # Explicit integer dtype: np.array([]) of an empty class would be float64 and
        # would turn the whole concatenated label array into floats.
        self.OK_label = np.zeros(self.n_OK, dtype=int)
        self.NG_label = np.ones(self.n_NG, dtype=int)
        self.Label_data = np.concatenate((self.OK_label, self.NG_label))
  
make_plot.py グラフ生成プログラム
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

class drawplots():
    """Draw the diagnostic figures used to choose an OK/NG decision threshold.

    Every ``draw_*`` method creates a new matplotlib figure and returns it; call
    ``plt.show()`` afterwards to display them.

    Args:
        Score_data: classifier scores of all samples.
        Label_data: true labels aligned with ``Score_data`` (0 = OK, 1 = NG).
        OK_score: scores of the OK samples (used by ``draw_histogram``).
        NG_score: scores of the NG samples (used by ``draw_histogram``).
    """

    def __init__(self, Score_data, Label_data, OK_score, NG_score):
        self.Score_data = Score_data
        self.Label_data = Label_data
        self.OK_score = OK_score
        self.NG_score = NG_score

    def draw_histogram(self, threshold_opt=None):
        """Overlay the OK and NG score histograms.

        Args:
            threshold_opt: optional threshold drawn as a dashed vertical line and shown
                in the title; ``None`` draws the distributions only.

        Returns:
            The created figure.
        """
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)

        ax.hist(self.OK_score, bins=100, color='blue', alpha=0.3, density=False, label='OK')
        ax.hist(self.NG_score, bins=100, color='red', alpha=0.3, density=False, label='NG')

        # Compare with None so a threshold of exactly 0.0 is still drawn, and set the
        # title per branch so the threshold title is no longer overwritten afterwards.
        if threshold_opt is not None:
            ax.axvline(x=threshold_opt, color='green', linestyle='dashed', linewidth=2)
            ax.set_title(f'OK,NG Data Distribution, threshold: {round(threshold_opt, 2)}')
        else:
            ax.set_title('OK,NG Data Distribution')

        ax.set_xlabel('Score')
        ax.set_ylabel('frequency')
        ax.grid()
        ax.legend()

        return fig

    def draw_ROC_curve(self, fpr, tpr, opt_idx, roc_auc_value):
        """Plot the ROC curve, the chance diagonal and the optimal point in red.

        Args:
            fpr, tpr: false/true positive rates (e.g. from ``roc_curve``).
            opt_idx: index (or index array) into fpr/tpr of the optimal threshold.
            roc_auc_value: AUC value shown in the title.

        Returns:
            The created figure (previously None; returned now for consistency).
        """
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)

        # Labels are required: legend() with no labelled artists only emits a warning.
        ax.plot(fpr, tpr, marker='o', label='ROC')
        ax.plot(fpr[opt_idx], tpr[opt_idx], marker='o', color='red', label='optimal threshold')
        ax.plot([0, 1], [0, 1], linestyle='--', color='gray', label='chance')

        ax.set_title(f'ROC Curve: {round(roc_auc_value, 3)}')
        ax.set_xlabel('FPR: False positive rate')
        ax.set_ylabel('TPR: True positive rate')

        ax.grid()
        ax.legend()
        return fig

    def draw_f1_score(self, threshold_from_pr, f1):
        """Plot F1 against the decision threshold and mark its maximum.

        Args:
            threshold_from_pr: thresholds returned by ``precision_recall_curve``.
            f1: F1 values computed from that curve's precision/recall. These have one
                more element than the thresholds (the final point has no threshold);
                the extra trailing element is ignored, so equal-length input also works.

        Returns:
            The created figure.
        """
        threshold_from_pr = np.asarray(threshold_from_pr)
        # Align lengths instead of duplicating the last threshold, which drew a
        # spurious vertical drop at the right edge of the curve.
        f1 = np.asarray(f1)[:len(threshold_from_pr)]

        # find the optimal threshold
        idx_opt = np.argmax(f1)
        threshold_opt = threshold_from_pr[idx_opt]

        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)
        ax.plot(threshold_from_pr, f1)
        ax.axvline(x=threshold_opt, color='green', linestyle='dashed', linewidth=2)
        ax.plot(threshold_opt, f1[idx_opt], marker='o', color='red')
        # The number shown is the threshold, not the F1 value, so label it as such.
        ax.set_title(f'F1 Score, optimal threshold: {round(threshold_opt, 3)}')
        ax.set_xlabel('Threshold')
        ax.set_ylabel('F1 Score')
        ax.grid()

        return fig

    @staticmethod
    def _row_percent(cm):
        """Return ``cm`` normalised per row (per true class) as percentages.

        Rows whose total is 0 (a class with no samples) become zeros instead of NaN.
        """
        cm = np.asarray(cm, dtype=np.float64)
        # keepdims keeps totals as a column so row i is divided by its own total;
        # dividing by a 1-D sum would broadcast column-wise and give wrong percentages.
        totals = cm.sum(axis=1, keepdims=True)
        return np.divide(cm, totals, out=np.zeros_like(cm), where=totals != 0) * 100

    def draw_confusion_matrix(self, Score_data, Label_data, threshold_opt):
        """Draw the row-normalised confusion matrix (in %) for ``threshold_opt``.

        Samples with a score >= threshold_opt are predicted NG (1), others OK (0).
        Each row (true class) sums to 100.

        Returns:
            The created figure.
        """
        predicted_label = (np.asarray(Score_data) >= threshold_opt).astype(int)

        classes = ['OK', 'NG']
        # labels=[0, 1] keeps the matrix 2x2 even if a class is absent or never predicted.
        cm = confusion_matrix(y_true=Label_data, y_pred=predicted_label, labels=[0, 1])
        cm = self._row_percent(cm)

        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)

        ax.imshow(cm, cmap=plt.cm.Blues)
        # imshow puts rows on the y axis and columns on the x axis.
        for row in range(cm.shape[0]):
            for col in range(cm.shape[1]):
                ax.text(x=col, y=row, s=round(cm[row, col], 2), va='center', ha='center', color="gray")
        ax.set_title(f'Confusion Matrix, Optimal Threshold : {round(threshold_opt, 2)}')
        ax.set_xlabel('Predicted')
        ax.set_ylabel('True')
        tick_marks = np.arange(len(classes))
        ax.set_xticks(tick_marks)
        ax.set_yticks(tick_marks)
        ax.set_xticklabels(classes)
        ax.set_yticklabels(classes)

        return fig

    def draw_precision_recall(self, precision, recall):
        """Plot recall (y) against precision (x).

        NOTE(review): the conventional PR plot puts recall on x and precision on y;
        the axis labels below match what is actually plotted here.

        Returns:
            The created figure.
        """
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)
        ax.plot(precision, recall)
        ax.set_title('Precision Recall Curve')
        ax.set_xlabel('Precision')
        ax.set_ylabel('Recall')
        ax.grid()

        return fig


3
3
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?