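1. Introduction
This article explains how to automatically find the threshold that separates OK from NG, using a classifier's output scores, when the OK and NG data are imbalanced.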
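2. Goal
Suppose a classifier assigns a score to each sample, and we have OK data and NG data whose score distributions overlap.
A threshold has to be fixed before this classifier is deployed. The goal is to choose it systematically rather than by hand.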
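3. Explanation
3.1 Creating the data
First, generate the OK and NG data with np.random.normal. OK samples are labeled 0 and NG samples are labeled 1, so NG is the positive class.
The data are generated under the following conditions:
- Number of OK samples: 10000
- Number of NG samples: 1000
- Mean score of the OK data: 2
- Mean score of the NG data: 4
- Standard deviation of the OK data: 0.5
- Standard deviation of the NG data: 0.8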
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, roc_auc_score, precision_recall_curve
from make_random_data import data_array
from make_plot import drawplots
#variable
number_OK = 1e4
number_NG = 1e3
OK_mean = 2
NG_mean = 4
OK_stdev = 0.5
NG_stdev = 0.8
#1. Data preparation
data = data_array(n_OK=number_OK, n_NG=number_NG, mean_OK=OK_mean, mean_NG=NG_mean, std_OK=OK_stdev, std_NG=NG_stdev)
Score_data,Label_data, OK_score, NG_score = data.make()
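data_array draws its scores from np.random.normal without fixing a seed, so each run produces slightly different data and a slightly different optimal threshold. The sketch below performs the same generation with NumPy's Generator API and a fixed seed, which makes results reproducible. The function name make_scores and the value seed=0 are illustrative assumptions and are not part of the original code.

```python
import numpy as np


def make_scores(n_ok=10_000, n_ng=1_000, mean_ok=2.0, mean_ng=4.0,
                std_ok=0.5, std_ng=0.8, seed=0):
    """Hypothetical helper: draw OK/NG scores from two normals (NG = label 1)."""
    rng = np.random.default_rng(seed)  # seeded generator -> reproducible data
    ok_score = rng.normal(loc=mean_ok, scale=std_ok, size=n_ok)
    ng_score = rng.normal(loc=mean_ng, scale=std_ng, size=n_ng)
    scores = np.concatenate([ok_score, ng_score])
    # OK samples get label 0 (negative), NG samples label 1 (positive class)
    labels = np.concatenate([np.zeros(n_ok, dtype=int), np.ones(n_ng, dtype=int)])
    return scores, labels, ok_score, ng_score


scores, labels, ok_score, ng_score = make_scores()
print(scores.shape, labels.sum())  # (11000,) 1000
```

With a fixed seed, the threshold reported in section 4 can be reproduced exactly from one run to the next.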
Create an instance for plotting, and draw the score histogram of the OK and NG data.
plots = drawplots(Score_data=Score_data, Label_data=Label_data, OK_score=OK_score, NG_score=NG_score)
plots.draw_histogram(threshold_opt=None)
3.2 Computation (ROC curve, AUC score, F1 score)
Compute the ROC curve and the AUC score.
#2.ROC Curve
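# drop_intermediate=False keeps every threshold, so the F1-optimal threshold also appears in `threshold`
fpr, tpr, threshold = roc_curve(y_true=Label_data, y_score=Score_data, drop_intermediate=False)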
#2.1.AUC score
roc_auc_value = roc_auc_score(y_true = Label_data, y_score = Score_data)
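The AUC can also be read as the probability that a randomly chosen NG sample receives a higher score than a randomly chosen OK sample, with ties counted as one half. The sketch below checks this interpretation against roc_auc_score on a smaller synthetic data set. The sample sizes and the seed are arbitrary choices.

```python
import numpy as np
from sklearn.metrics import roc_auc_score

rng = np.random.default_rng(1)
ok = rng.normal(2.0, 0.5, size=2000)  # OK samples (label 0)
ng = rng.normal(4.0, 0.8, size=200)   # NG samples (label 1)
y_true = np.r_[np.zeros(ok.size, dtype=int), np.ones(ng.size, dtype=int)]
y_score = np.r_[ok, ng]

# Fraction of (NG, OK) pairs in which the NG sample scores higher; ties count 1/2
greater = (ng[:, None] > ok[None, :]).mean()
ties = (ng[:, None] == ok[None, :]).mean()
auc_pairwise = greater + 0.5 * ties

auc_sklearn = roc_auc_score(y_true, y_score)
print(auc_pairwise, auc_sklearn)
```

The pairwise version builds an n_NG × n_OK comparison matrix, so it is only practical for moderate sizes. roc_auc_score reaches the same value from the sorted ROC curve.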
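Next, compute the F1 score at every threshold of the precision-recall curve, F1 = 2 × precision × recall / (precision + recall). Where precision + recall is 0, F1 is set to 0 to avoid division by zero.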
#3.f1 Value
# positional arguments work across scikit-learn versions (the keyword probas_pred was deprecated in favor of y_score)
precision, recall, threshold_from_pr = precision_recall_curve(Label_data, Score_data)
a = 2* precision * recall
b = precision + recall
f1 = np.divide(a,b,out=np.zeros_like(a), where=b!=0)
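precision[i] and recall[i] describe the hard predictions "score >= thresholds[i] means NG". The sketch below checks this relationship against sklearn.metrics.f1_score at one threshold. It also shows that precision and recall are one element longer than the threshold array. The data, the seed, and the chosen index are arbitrary.

```python
import numpy as np
from sklearn.metrics import f1_score, precision_recall_curve

rng = np.random.default_rng(2)
y_score = np.r_[rng.normal(2.0, 0.5, 2000), rng.normal(4.0, 0.8, 200)]
y_true = np.r_[np.zeros(2000, dtype=int), np.ones(200, dtype=int)]

precision, recall, thresholds = precision_recall_curve(y_true, y_score)
# The final (precision=1, recall=0) point has no threshold of its own,
# so precision and recall are one element longer than thresholds.
print(len(precision), len(thresholds))

num = 2 * precision * recall
den = precision + recall
f1 = np.divide(num, den, out=np.zeros_like(num), where=den != 0)

# Rebuild hard predictions at one threshold: score >= threshold -> NG (1)
i = len(thresholds) // 2
y_pred = (y_score >= thresholds[i]).astype(int)
print(f1[i], f1_score(y_true, y_pred))
```

The extra final point is the reason the optimal index should be searched only over the first len(thresholds) entries of f1.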
3.3 Computing the optimal threshold
Find the index that maximizes the F1 score.
#4. Optimal Value
#4.1 find optimal threshold
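# the last PR point (precision=1, recall=0) has no threshold, so it is excluded
idx_opt = np.argmax(f1[:-1])
This is where the calculation gets a little tricky. idx_opt is an index into the precision-recall arrays, so it gives the optimal threshold (used for the confusion matrix) directly. The ROC arrays, however, come with their own threshold array of a different length and order, so idx_opt cannot be used on fpr and tpr. Instead, look up the position of the same threshold in the ROC threshold array.
# This is the tricky part
threshold_opt = threshold_from_pr[idx_opt]  # used for the confusion matrix
# roc_curve has its own threshold array (different length and order),
# so find the ROC index whose threshold is closest to threshold_opt
idx_opt_from_pr = np.argmin(np.abs(threshold - threshold_opt))  # used for the ROC curve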
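The index bookkeeping above can be packaged into a single function. The sketch below assumes binary labels with NG = 1 and the same "score >= threshold means NG" rule that precision_recall_curve uses. The function name f1_optimal_point is hypothetical. Rather than searching the ROC threshold array, it evaluates FPR and TPR directly at the chosen threshold, which avoids the index mismatch altogether. A cross-check against roc_curve is included.

```python
import numpy as np
from sklearn.metrics import precision_recall_curve, roc_curve


def f1_optimal_point(y_true, y_score):
    """Hypothetical helper: F1-optimal threshold plus its (FPR, TPR) point."""
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score)
    precision, recall, thr_pr = precision_recall_curve(y_true, y_score)
    num = 2 * precision * recall
    den = precision + recall
    f1 = np.divide(num, den, out=np.zeros_like(num), where=den != 0)
    # f1 has one more entry than thr_pr; the extra last point has no threshold
    idx = int(np.argmax(f1[:-1]))
    thr_opt = thr_pr[idx]
    # Evaluate the ROC point directly: predict NG (1) when score >= thr_opt
    y_pred = y_score >= thr_opt
    tpr = y_pred[y_true == 1].mean()  # share of NG samples flagged as NG
    fpr = y_pred[y_true == 0].mean()  # share of OK samples flagged as NG
    return thr_opt, f1[idx], fpr, tpr


rng = np.random.default_rng(3)
y_score = np.r_[rng.normal(2.0, 0.5, 10_000), rng.normal(4.0, 0.8, 1_000)]
y_true = np.r_[np.zeros(10_000, dtype=int), np.ones(1_000, dtype=int)]
thr_opt, f1_opt, fpr_opt, tpr_opt = f1_optimal_point(y_true, y_score)
print(f"threshold={thr_opt:.3f}  F1={f1_opt:.3f}  FPR={fpr_opt:.4f}  TPR={tpr_opt:.3f}")

# Cross-check against roc_curve, keeping every threshold
fpr_all, tpr_all, thr_roc = roc_curve(y_true, y_score, drop_intermediate=False)
j = int(np.flatnonzero(thr_roc == thr_opt)[0])
print(fpr_all[j], tpr_all[j])
```

Computing the ROC point from the threshold itself does not depend on how roc_curve orders or thins its thresholds.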
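4. Results
Draw the F1 curve, the precision-recall curve, the ROC curve, the confusion matrix, and the histogram with the optimal threshold.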
#5 draw f1 score
plots.draw_f1_score(threshold_from_pr=threshold_from_pr, f1=f1)
#6.draw precision recall curve
plots.draw_precision_recall(precision=precision, recall=recall)
#7. Plot ROC curve
plots.draw_ROC_curve(fpr=fpr, tpr=tpr, opt_idx=idx_opt_from_pr, roc_auc_value=roc_auc_value)
#8. Draw Confusion Matrix
plots.draw_confusion_matrix(Score_data=Score_data, Label_data=Label_data, threshold_opt=threshold_opt)
#09. Draw Histogram with Threshold
plots.draw_histogram(threshold_opt=threshold_opt)
plt.show()
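4.1 F1 score with the optimal threshold
The threshold that maximizes the F1 score is taken as the optimal value. In this run the threshold is 3.245. Because the data are random, the value changes slightly from run to run.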
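Because the data come from two known normal distributions, the reported value can be sanity-checked without sampling. Replacing the counts TP, FP, and FN with their expected values gives a population-level F1 curve, F1 = 2TP / (2TP + FP + FN), whose maximum can be located on a fine grid. This sketch rests on that assumption: it uses scipy.stats.norm, and it computes the F1 of the expected counts, not the expected F1 of a finite sample.

```python
import numpy as np
from scipy.stats import norm

# Parameters from section 3.1
n_ok, n_ng = 10_000, 1_000
mean_ok, std_ok = 2.0, 0.5
mean_ng, std_ng = 4.0, 0.8

t = np.linspace(2.0, 4.5, 2501)                    # candidate thresholds, step 0.001
tp = n_ng * norm.sf(t, loc=mean_ng, scale=std_ng)  # expected NG with score >= t
fn = n_ng - tp                                     # expected NG with score < t
fp = n_ok * norm.sf(t, loc=mean_ok, scale=std_ok)  # expected OK with score >= t
f1 = 2 * tp / (2 * tp + fp + fn)                   # same as 2PR / (P + R)

i = np.argmax(f1)
print(f"threshold ~ {t[i]:.3f}, F1 ~ {f1[i]:.3f}")
```

With these parameters the maximum lies a little above 3.2, and the curve is very flat near its peak. A value such as 3.245 found on sampled data is therefore in the expected range.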
4.2 ROC curve with the optimal threshold
4.3 Data distribution with the optimal threshold
4.4 Confusion matrix
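The confusion matrix here is shown as percentages of each true class, so each row should add up to 100. The sketch below row-normalizes with NumPy on made-up toy labels and compares the result with scikit-learn's confusion_matrix(..., normalize='true'). It also shows an easy mistake: dividing by cm.sum(axis=1) without keepdims divides column by column instead of row by row.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Toy labels: 8 OK (0) and 2 NG (1); one OK and one NG are misclassified
y_true = np.array([0, 0, 0, 0, 0, 0, 0, 0, 1, 1])
y_pred = np.array([0, 0, 0, 0, 0, 0, 0, 1, 1, 0])

cm = confusion_matrix(y_true, y_pred)  # rows = true class, columns = predicted class
# Row-normalize: keepdims=True keeps the row sums as shape (2, 1)
cm_pct = cm / cm.sum(axis=1, keepdims=True) * 100  # rows: [87.5, 12.5] and [50.0, 50.0]
# Without keepdims the (2,) row sums broadcast across columns instead
cm_wrong = cm / cm.sum(axis=1) * 100

print(cm_pct)
print(confusion_matrix(y_true, y_pred, normalize='true') * 100)
print(cm_wrong.sum(axis=1))  # rows no longer sum to 100
```

With imbalanced data, row percentages make the NG row readable even though NG has far fewer samples than OK.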
4.5 Precision-Recall curve
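5. Summary
This article showed one way to tackle a common problem in classification, finding the optimal threshold. The threshold is chosen to maximize the F1 score computed from the precision-recall curve. It is then checked on the ROC curve, the confusion matrix, and the score histogram.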
6. Full program code
Main program
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, roc_auc_score, precision_recall_curve
from make_random_data import data_array
from make_plot import drawplots
#variable
number_OK = 1e4
number_NG = 1e3
OK_mean = 2
NG_mean = 4
OK_stdev = 0.5
NG_stdev = 0.8
#1. Data preparation
data = data_array(n_OK=number_OK, n_NG=number_NG, mean_OK=OK_mean, mean_NG=NG_mean, std_OK=OK_stdev, std_NG=NG_stdev)
Score_data,Label_data, OK_score, NG_score = data.make()
plots = drawplots(Score_data=Score_data, Label_data=Label_data, OK_score=OK_score, NG_score=NG_score)
plots.draw_histogram(threshold_opt=None)
print(Score_data)
print(Label_data)
#2.ROC Curve
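# drop_intermediate=False keeps every threshold, so the F1-optimal threshold also appears in `threshold`
fpr, tpr, threshold = roc_curve(y_true=Label_data, y_score=Score_data, drop_intermediate=False)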
#2.1.AUC score
roc_auc_value = roc_auc_score(y_true = Label_data, y_score = Score_data)
#3.f1 Value
# positional arguments work across scikit-learn versions (the keyword probas_pred was deprecated in favor of y_score)
precision, recall, threshold_from_pr = precision_recall_curve(Label_data, Score_data)
a = 2* precision * recall
b = precision + recall
f1 = np.divide(a,b,out=np.zeros_like(a), where=b!=0)
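#4. Optimal Value
#4.1 find optimal threshold
# the last PR point (precision=1, recall=0) has no threshold, so it is excluded
idx_opt = np.argmax(f1[:-1])
# This is the tricky part: PR-curve indices do not line up with ROC-curve indices
threshold_opt = threshold_from_pr[idx_opt]  # used for the confusion matrix
# find the ROC index whose threshold is closest to threshold_opt
idx_opt_from_pr = np.argmin(np.abs(threshold - threshold_opt))  # used for the ROC curve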
#5 draw f1 score
plots.draw_f1_score(threshold_from_pr=threshold_from_pr, f1=f1)
#6.draw precision recall curve
plots.draw_precision_recall(precision=precision, recall=recall)
#7. Plot ROC curve
plots.draw_ROC_curve(fpr=fpr, tpr=tpr, opt_idx=idx_opt_from_pr, roc_auc_value=roc_auc_value)
#8. Draw Confusion Matrix
plots.draw_confusion_matrix(Score_data=Score_data, Label_data=Label_data, threshold_opt=threshold_opt)
#09. Draw Histogram with Threshold
plots.draw_histogram(threshold_opt=threshold_opt)
plt.show()
make_random_data.py: data generation program
import numpy as np


class data_array():
    def __init__(self, n_OK=1e4, n_NG=1e4, mean_OK=2.5, mean_NG=4, std_OK=0.5, std_NG=0.5):
        self.n_OK = int(n_OK)
        self.n_NG = int(n_NG)
        self.mean_OK = mean_OK
        self.mean_NG = mean_NG
        self.std_OK = std_OK
        self.std_NG = std_NG
        self.Score_data = None
        self.Label_data = None
        self.OK_score = None
        self.NG_score = None

    def make(self):
        self._make_score()
        self._make_label()
        return self.Score_data, self.Label_data, self.OK_score, self.NG_score

    def _make_score(self):
        # scores are drawn from two normal distributions (no fixed seed)
        self.OK_score = np.random.normal(loc=self.mean_OK, scale=self.std_OK, size=self.n_OK)
        self.NG_score = np.random.normal(loc=self.mean_NG, scale=self.std_NG, size=self.n_NG)
        self.Score_data = np.concatenate((self.OK_score, self.NG_score))

    def _make_label(self):
        # OK -> 0 (negative class), NG -> 1 (positive class)
        self.OK_label = np.zeros(self.n_OK, dtype=int)
        self.NG_label = np.ones(self.n_NG, dtype=int)
        self.Label_data = np.concatenate((self.OK_label, self.NG_label))
make_plot.py: plotting program
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix


class drawplots():
    def __init__(self, Score_data, Label_data, OK_score, NG_score):
        self.Score_data = Score_data
        self.Label_data = Label_data
        self.OK_score = OK_score
        self.NG_score = NG_score

    def draw_histogram(self, threshold_opt=None):
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)
        ax.hist(self.OK_score, bins=100, color='blue', alpha=0.3, density=False, label='OK')
        ax.hist(self.NG_score, bins=100, color='red', alpha=0.3, density=False, label='NG')
        # compare with None so that a threshold of 0.0 is still drawn
        if threshold_opt is not None:
            ax.axvline(x=threshold_opt, color='green', linestyle='dashed', linewidth=2)
            ax.set_title(f'OK,NG Data Distribution, threshold: {round(threshold_opt, 2)}')
        else:
            ax.set_title('OK,NG Data Distribution')
        ax.set_xlabel('Score')
        ax.set_ylabel('frequency')
        ax.grid()
        ax.legend()
        return fig

    def draw_ROC_curve(self, fpr, tpr, opt_idx, roc_auc_value):
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)
        ax.plot(fpr, tpr, marker='o', label='ROC curve')
        # ROC point of the F1-optimal threshold
        ax.plot(fpr[opt_idx], tpr[opt_idx], marker='o', color='red', label='optimal threshold')
        # diagonal = classifier that guesses at random
        ax.plot([0, 1], [0, 1], linestyle='--', color='gray', label='chance')
        ax.set_title(f'ROC Curve, AUC: {round(roc_auc_value, 3)}')
        ax.set_xlabel('FPR: False positive rate')
        ax.set_ylabel('TPR: True positive rate')
        ax.grid()
        ax.legend()
        return fig

    def draw_f1_score(self, threshold_from_pr, f1):
        # find the optimal index (the last PR point has no threshold)
        idx_opt = np.argmax(f1[:-1])
        # f1 has one more element than the thresholds, so repeat the last
        # threshold to get arrays of equal length for plotting
        threshold2_new = np.append(threshold_from_pr, threshold_from_pr[-1])
        threshold_opt = threshold_from_pr[idx_opt]
        # draw
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)
        ax.plot(threshold2_new, f1)
        ax.axvline(x=threshold_opt, color='green', linestyle='dashed', linewidth=2)
        ax.plot(threshold2_new[idx_opt], f1[idx_opt], marker='o', color='red')
        ax.set_title(f'F1 Score, optimal threshold: {round(threshold_opt, 3)}')
        ax.set_xlabel('Threshold')
        ax.set_ylabel('F1 Score')
        ax.grid()
        return fig

    def draw_confusion_matrix(self, Score_data, Label_data, threshold_opt):
        # predicted label: score >= threshold -> NG (1), otherwise OK (0)
        predicted_label = (np.asarray(Score_data) >= threshold_opt).astype(int)
        # label_class
        classes = ['OK', 'NG']
        cm = confusion_matrix(y_true=Label_data, y_pred=predicted_label)
        # percentage of each true class: divide every row by its own total
        cm = cm / cm.sum(axis=1, keepdims=True) * 100
        # plot
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)
        cmap = plt.cm.Blues
        ax.imshow(cm, cmap=cmap)
        for m in range(cm.shape[0]):      # row = true class
            for n in range(cm.shape[1]):  # column = predicted class
                ax.text(x=n, y=m, s=f'{cm[m, n]:.2f}', va='center', ha='center', color='gray')
        ax.set_title(f'Confusion Matrix, Optimal Threshold : {round(threshold_opt, 2)}')
        ax.set_xlabel('Predicted')
        ax.set_ylabel('True')
        tick_marks = np.arange(len(classes))
        ax.set_xticks(tick_marks)
        ax.set_yticks(tick_marks)
        ax.set_xticklabels(classes)
        ax.set_yticklabels(classes)
        return fig

    def draw_precision_recall(self, precision, recall):
        # plot: recall on the x-axis, precision on the y-axis (usual convention)
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)
        ax.plot(recall, precision)
        ax.set_title('Precision Recall Curve')
        ax.set_xlabel('Recall')
        ax.set_ylabel('Precision')
        ax.grid()
        return fig