データ分析の結果を精度を高めるためにどこで誤分類が起きたのかを特定しよう
というのが今回のテーマです。
そこで、今日は confusion matrix を使って、どこで誤分類が起きたのかを可視化して見ていていきます。
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import confusion_matrix
clf = DecisionTreeClassifier()
clf.fit(X_train, Y_train)
result = clf.predict(X_test)
cm = confusion_matrix(Y_test, result)
print(cm)
iris データセットを使うと下記の図のように可視化されます。
Extracted from sklearn 公式ドキュメント
少し小さくて、見えにくいかもしれませんが、y軸が True value、つまり、正解のラベリングで、x軸が、Predicted value で、機械学習モデルを使用してラベリングされたものです。 上記の図で見ると、中央行、右のところで誤分類が起きています。
これを認識した上で、データの前処理を見直してみたり、機械学習モデルのパラメータを調整しなおしてみることで、精度が上がるかもしれませんね。