概要
テストデータ(1000件)を用いて、再学習したdistilbertの性能測定を行う。
対象読者
- 人工知能・機械学習・深層学習の概要は知っているという方
- 理論より実装を重視する方
- Pythonを触ったことがある方
目次
再学習モデルの予測性能を計測する指標
機械学習モデルの予測性能測定でよく使われる指標はAccuracy(正解率), Recall(再現率), Precision(適合率), F値(F-measure/F1スコア)がある。
それぞれの定義を説明するにはまずTP(True Positive)/FP(False Positive)/TN(True Negative)/FP(False Positive)を理解する必要がある。これらの指標を理解するためには実例が
欠かせない。
実例:新型コロナのPCR検査
新型コロナの感染者もしくは非感染者が、PCR検査を受けた結果が出たというシチュエーションを考える。PCR検査も100%の精度ではないため、新型コロナ患者なのに感染していない(陰性)、となることもあれば、感染していないのに陽性となることもある。TP/FP/TN/FPは以下のとおりとなる。
- 新型コロナの感染者が陽性と判定された:TP
- 新型コロナの感染者が陰性と判定された:FN
- 新型コロナの非感染者が陽性と判定された:FP
- 新型コロナの非感染者が陰性と判定された:TN
- | 感染者(Positive) | 非感染者(Negative) |
---|---|---|
PCR検査で陽性と予測(Positive) | TP | FP |
PCR検査で陰性と予測(Negative) | FN | TN |
以上の表を常に参照すれば分かるのだが、検定試験などでは暗記しないといけないケースもある。以下の2つの順に考え、1.と2.を連結すると覚えやすい。
- 正解(=実際に感染しているのか)と予測(=PCR検査結果)があっていればTrue, 誤っていればFalseとなる。
- 予測結果(=PCR検査結果)が陽性ならPositive, 陰性ならNegativeを後ろにつける。
たとえば、非感染者が陰性と判定された場合、1.はTrueとなり、2.は陰性なのでNegativeとなる。
よって、True Negativeを示す。
上記の表を混合マトリクス(confusion matrix)と呼ぶ。上記表のTP/FP/TN/FNを利用することでAccuracy(正解率), Recall(再現率), Precision(適合率), F値(F-measure/F1スコア)を説明できる。
Accuracy(正解率)
読んで字のごとく、全体のうちどれだけ正解があるかだ。
\frac{TP+TN}{TP+TN+FP+FN}
Recall(再現率)
再現よりも回収率だとわかりやすいと賢者がQiita記事で教えてくれた。
つまり、少々物騒な物言いだが、感染者をどれだけPCR検査で回収できるかということだ。
\frac{TP}{TP+FN}
Precision(適合率)
陽性と予測したもののみに絞った正答率のことを適合率と呼ぶ。英語でPreciseは精密な、という意味があることからも精度のように厳密に予測の性能を図ることができる。再現率とはトレードオフの関係にある。
\frac{TP}{TP+FP}
F値(F-measure/F1スコア)
Recall(再現率)とPrecision(適合率)はトレードオフの関係にあることから、これらの調和平均をとって性能指標とすることが多い。これをF値(F-measure/F1スコア)と呼ぶ。
\frac{2\times{Recall}\times{Precision}}{Recall + Precision}
以上で準備は完了。
11本目:正解率(Accuracy)
-
質問:8本目で作成したtest_datasetと10本目で再学習したdistilbertモデルを用いて、推論を行い、正解率(Accuracy)を計測せよ。
-
回答:
推論コード
# 重みを更新しないためにevaluaton(eval)モードにする。
model.eval()
l = []
for test_text in test_texts:
# エンコード
input_tokens = tokenizer([test_text], truncation=True, padding=True)
# 推論
outputs = model(torch.tensor(input_tokens['input_ids']).to(device))
# 結果をcpuに転送しリストに追加
l.append(torch.argmax(outputs['logits'], axis=1).item())
Accuracy(正解率)計算コード
correct_cnt = 0
for pred, ans in zip(l, test_labels):
if pred == ans:
correct_cnt += 1
# 正解率(=全データのうち正解がどれだけか)
print(correct_cnt/len(test_labels))
- 結果:
0.869
12本目:再現率・適合率・F値(F1-Score)
-
質問:11本目で推論した結果を用いて、再現率・適合率・F値(F1-Score)を計測せよ。
-
回答:
まずはTP/TN/FP/FNを計算する。
極性判定のネガティブが1でポジティブが0なのが若干紛らわしいので注意
tp = 0
tn = 0
fp = 0
fn = 0
for pred, ans in zip(l, test_labels):
if pred == ans:
if pred ==1:
tp += 1
else:
tn += 1
elif pred == 0:
fn += 1
elif pred == 1:
fp +=1
Recallはこちら。
recall = tp/(tp+fn)
print(recall)
*結果
0.8861788617886179
Precisionはこちら。
precision = tp/(tp+fp)
print(precision)
*結果
0.8532289628180039
F1スコアはこちら。
2 * precision * recall / (precision + recall)
0.8693918245264207
13本目:Classification Report
-
質問:scikit-learnのモジュールを用いてClassification Reportを作成し、12本目で計算した再現率・適合率・F値(F1-Score)と比較せよ。
-
回答:
なんと一行で書けるのだ。sklearnは使わない手はない。ただ、中身を知っていないと結果解釈できないので、今回は丁寧に解説してみました。
from sklearn.metrics import classification_report
print(classification_report(test_labels, l))
- 結果
参考文献
-
[【入門者向け】機械学習の分類問題評価指標解説(正解率・適合率・再現率など)]("
https://qiita.com/FukuharaYohei/items/be89a99c53586fa4e2e4")
著者
ツイッターでPython/numpy/pandas/pytorch関連の有益なツイートを配信してます。