突然ですが、私はリアルでは字が汚いと言われることがあります。どれくらい汚いかと言われると、大学院生時代に生協で備品を注文しようとした時に個数を数字で書くと職員さんが読めないレベルで字が汚いらしいです(自分は達筆なのだと思っています)。
ところで機械学習で字といえばMNISTを想像される方が多いと思いますが、あれって実際機械学習を勉強するにしては字がきれいなのかとふと疑問に思うことがあります。生データを見てみると「え、読めない」と思うことがあるのですが、じゃあ機械学習を使ってみるとどうなるのかちょっと覗いてみましょう。
コーディング
ライブラリのインポート
from sklearn.datasets import load_digits
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import train_test_split as tts
import numpy as np
import matplotlib.pyplot as plt
データの読み込み
今回は今までに書いた記事と違ってScikit-Learnのトイデータをそのまま使わせてもらいます。
x = load_digits().data
y = load_digits().target
x = x / max(x.flatten())
一応前処理として大丈夫だとは思いますが正規化を行って取り得る値を0から1までに直します。
学習とスコア
では実際に学習させてみてスコアを見てみましょう。
まずデータは全て同じになるようにテストデータのサイズとランダムにテストデータと訓練データを分けるものは指定します。
ニューラルネットワークの中間層は500、1000、750、500にします(いつもだいたいこの数字)。
x_train, x_test, y_train, y_test = tts(x, y, test_size=0.2, random_state=1)
model = MLPClassifier(hidden_layer_sizes=(500, 1000, 750, 500))
model.fit(x_train, y_train)
y_pred = model.predict(x_test)
model.score(x_test, y_test)
結果正解率は0.9861111111111112となりました。
間違えたデータを探す
精度は結果として100%ではありませんでした、つまりどこかしら数字を誤読したという事が考えられます。
ではどの数字をどのように間違えて、なおかつSoftmax関数の値からどのくらいの確率でそう答えたのか見てみましょう。
miss = []
y_predSM = model.predict_proba(x_test)
for i in range(len(y_test)):
if y_test[i] != y_pred[i]:
miss.append([i, y_test[i], y_pred[i], max(y_predSM[i])])
df_miss = pd.DataFrame(miss)
df_miss.columns = ["ID", "real", "predict", "probably"]
df_miss
では一番上に表示されているID108番の文字について画像で出力してみましょう。
plt.imshow(x_test[108].reshape(8, 8), cmap="gray")
plt.title("real=9 predict=5 probably=0.99")
plt.show()
確かに9とも5とも見て取れますね。ぶっちゃけこれだけ見せられたら答えに戸惑います。
まとめ
というわけで今回はMNISTの誤読をテーマにしてみましたが、これは実際にはある意味ややこしいデータを見つけることもできます。そのややこしいデータを見つけてどうするかは作り手次第ですが、統計的にややこしい、画像がややこしい等、分類アルゴリズムは100%当たる訳でもありませんし、何%の確率でそう予測したかも出てきますのでMNISTに限らず様々な発見につながると思います。