#はじめに
① https://qiita.com/yohiro/items/04984927d0b455700cd1
② https://qiita.com/yohiro/items/5aab5d28aef57ccbb19c
③ https://qiita.com/yohiro/items/cc9bc2631c0306f813b5
④ https://qiita.com/yohiro/items/d376f44fe66831599d0b
⑤ https://qiita.com/yohiro/items/3abaf7b610fbcaa01b9c
の続き
- 参考教材:Udemy みんなのAI講座 ゼロからPythonで学ぶ人工知能と機械学習
- 使用ライブラリ:scikit-learn
#課題設定
手書きの数字画像(8×8 px)から、書いてある数字を認識する。
#ソースコード
##インポート
from sklearn import datasets
from sklearn import svm
from sklearn import metrics
import matplotlib.pyplot as plt
#サンプルデータの読み込み
# 数字データの読み込み
digits = datasets.load_digits()
digitsには、以下のようなデータが入っている。
[[ 0. 0. 5. ... 0. 0. 0.]
[ 0. 0. 0. ... 10. 0. 0.]
[ 0. 0. 0. ... 16. 9. 0.]
...
[ 0. 0. 1. ... 6. 0. 0.]
[ 0. 0. 2. ... 12. 0. 0.]
[ 0. 0. 10. ... 12. 1. 0.]]
[0 1 2 ... 8 9 8]
digits.data
は64×1797のリストで、要素の値はグレースケールにおける色を表しており、一つの64要素リストが一つの画像を表している。画像表示用にdigits.image
にもリスト形式は異なるが同様の情報が入っている。
digits.target
はそれぞれの画像の正解(=どの数字を表しているか)を示している。
##サポートベクターマシンによる訓練
# サポートベクターマシン
clf = svm.SVC(gamma=0.001, C=100.0) # gamma:一つの訓練データが与える影響の大きさ, C:誤認識許容度
# サポートベクターマシンによる訓練(6割のデータを使用、残りの4割は検証用)
clf.fit(digits.data[:int(n*6/10)], digits.target[:int(n*6/10)])
前回使ったのはLinearSVC()
だったが、今回はSVC()
を使用している。
線形の境界線では分類ができないから?
##分類
上記で作成したclfにdigits.dataの残りの4割のデータを読ませ、どの数字になるか、それぞれ分類させる。
# 正解
expected = digits.target[int(-n*4/10):]
# 予測
predicted = clf.predict(digits.data[int(-n*4/10):])
# 正解率
print(metrics.classification_report(expected, predicted))
# 誤認識のマトリックス
print(metrics.confusion_matrix(expected, predicted))
##結果
###正解率
precision recall f1-score support
0 0.99 0.99 0.99 70
1 0.99 0.96 0.97 73
2 0.99 0.97 0.98 71
3 0.97 0.86 0.91 74
4 0.99 0.96 0.97 74
5 0.95 0.99 0.97 71
6 0.99 0.99 0.99 74
7 0.96 1.00 0.98 72
8 0.92 1.00 0.96 68
9 0.96 0.97 0.97 71
accuracy 0.97 718
macro avg 0.97 0.97 0.97 718
weighted avg 0.97 0.97 0.97 718
0と予測したものは99%が正解、正解が0だった内正しく0と予想されたものは99%、のように読む。
表の読み方の参考:
###誤認識マトリックス
[[69 0 0 0 1 0 0 0 0 0]
[ 0 70 1 0 0 0 0 0 2 0]
[ 1 0 69 1 0 0 0 0 0 0]
[ 0 0 0 64 0 3 0 3 4 0]
[ 0 0 0 0 71 0 0 0 0 3]
[ 0 0 0 0 0 70 1 0 0 0]
[ 0 1 0 0 0 0 73 0 0 0]
[ 0 0 0 0 0 0 0 72 0 0]
[ 0 0 0 0 0 0 0 0 68 0]
[ 0 0 0 1 0 1 0 0 0 69]]
0の画像の内、0と認識されたものが69件、4と認識されたものが1件、のように読む。
##実際の画像と予測値
# 予測と画像の対応(一部)
images = digits.images[int(-n*4/10):]
for i in range(12):
plt.subplot(3, 4, i + 1)
plt.axis("off")
plt.imshow(images[i], cmap=plt.cm.gray_r, interpolation="nearest")
plt.title("Guess: " + str(predicted[i]))
plt.show()
数字を認識できていることがわかる。
#おまけ
digits.dataを可視化してみた(白黒の2値画像)
for i in range(10):
my_s = ""
for k, j in enumerate(digits.data[i]):
if (j > 0):
my_s += " ■ "
else:
my_s += " "
if k % 8 == 7:
print(my_s)
my_s = ""
print("\n")
結果
■ ■ ■ ■
■ ■ ■ ■ ■
■ ■ ■ ■ ■
■ ■ ■ ■
■ ■ ■ ■
■ ■ ■ ■ ■
■ ■ ■ ■ ■
■ ■ ■
■ ■ ■
■ ■ ■
■ ■ ■ ■
■ ■ ■ ■ ■
■ ■ ■ ■
■ ■ ■ ■
■ ■ ■ ■
■ ■ ■
...
■ ■ ■ ■
■ ■ ■ ■
■ ■ ■ ■
■ ■ ■ ■ ■
■ ■ ■ ■
■ ■ ■ ■ ■ ■
■ ■ ■ ■ ■ ■
■ ■ ■ ■ ■
■ ■
■ ■ ■ ■ ■
■ ■ ■ ■ ■
■ ■ ■ ■ ■
■ ■ ■ ■ ■
■ ■ ■
■ ■ ■
■ ■ ■ ■
なんとなく手書き文字になっていることがわかる