ふと思ったんですがMNISTは8×8で次元数(64次元)の少ないデータセットだけど可視化でき、そしてそれが数字を表していると。
じゃあAIはどこを見て文字を判断しているか簡単に算出して可視化できるのではないかと思ってやってみました。
from sklearn.tree import DecisionTreeClassifier as DTC
from sklearn.datasets import load_digits
import matplotlib.pyplot as plt
x = load_digits().data
y = load_digits().target
model = DTC()
model.fit(x, y)
imp = model.feature_importances_
imp = (imp - min(imp)) / (max(imp) - min(imp))
imp = imp.reshape(8, 8)
plt.imshow(imp, cmap="gray")
plt.colorbar()
plt.show()
というわけで可視化したらこのようになりました。
白くなればなるほどAIは重要視しているということになります。
きっとこれ0~9の共通している部分がグレーで白が何かを区別するところなんでしょう。
とりあえず横端は関係なく真ん中らへんと上と下で判断材料があるということでしょうね。
字を意識したことなかったのでこれから文字書くとき気を付けます。