LoginSignup
4
4

More than 3 years have passed since last update.

多クラス分類の結果の混同行列の集計、評価指標の計算

Last updated at Posted at 2019-11-07

多クラス分類で、未知データを入力とした際、各クラスへの分類確率の予測値が算出される。

テストデータを用いた評価の際、混同行列を書く必要があるので、この記事では、その書き方を例を示す。

データの用意

3つのクラスに対して、5つのデータを分類器に入力した結果を results、正解を y_true とする。

results = [
    [0.1, 0.2, 0.7],
    [0.5, 0.3, 0.3],
    [0.2, 0.3, 0.2],
    [0.2, 0.7, 0.2],
    [0.2, 0.5, 0.2]
]
y_true = [
    [0, 0, 1],
    [1, 0, 0],
    [0, 0, 1],
    [0, 0, 1],
    [0, 1, 0]
]

混同行列を出す

利用するものとしては、 sklearn.metrics.confusion_matrix を使う。

sklearn.metrics.confusion_matrix

この入力に合う形にデータを整形する。

import numpy as np

results2 = [np.argmax(i) for i in results]
# [2, 0, 1, 1, 1]

y_true2 = [np.argmax(i) for i in y_true]
# [2, 0, 2, 2, 1]

よって以下のように集計できる。

from sklearn.metrics import confusion_matrix
confusion_matrix(y_true2, results2)

# array([[1, 0, 0],
#        [0, 1, 0],
#        [0, 2, 1]])

seabornで可視化する

seabornを使うことで、sklearn.metrics.confusion_matrixの結果を美しく可視化できるのでやってみよう。

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

labels = sorted(list(set(y_true2)))

df = pd.DataFrame(confusion_matrix(y_true2, results2), index=labels, columns=labels)

plt.figure()
sns.heatmap(df, annot=True) # annot = Falseにすると図中の各セルの数値表示がなくなる。
plt.show()

スクリーンショット 2019-11-07 9.47.07.png

↑ 図がずれているのはこちらのバグ?のようだ。

おまけ

results を, y_true と同様に、binaryで表示すると、 かつての記事でも書いた sklearn.metrics.classification_report が利用できる。これにより、各種評価指標を一度に計算してくれる。

まずは変換。

b_results = []
for i in results:
    max_index = np.argmax(i)
    tmp = [0] * max(map(len, result))
    tmp[max_index] = 1
    b_results.append(tmp)
# [[0, 0, 1], [1, 0, 0], [0, 1, 0], [0, 1, 0], [0, 1, 0]]

sklearn.metrics.classification_report を利用する。

from sklearn.metrics import classification_report

print(classification_report(np.array(y_true), np.array(b_results)))

出力結果。

              precision    recall  f1-score   support

           0       1.00      1.00      1.00         1
           1       0.33      1.00      0.50         1
           2       1.00      0.33      0.50         3

   micro avg       0.60      0.60      0.60         5
   macro avg       0.78      0.78      0.67         5
weighted avg       0.87      0.60      0.60         5
 samples avg       0.60      0.60      0.60         5
4
4
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
4