やったこと
今回、機械学習のアルゴリズムk近傍法(k-nn)を用いて手書き数字認識を行った。
データの読み込み
from sklearn.datasets import fetch_openml
mnist = fetch_openml("mnist_784")
sklearnのメソッドfetch_openmkでMNISTから手書き数字データmnist_784を取得する。
mnist_784は
- 訓練データ6万枚、テストデータ1万枚の計7万枚の画像量
- 28(縦)×28(幅)ピクセル
- 8bitグレースケール=(白黒画像)白「0」~黒「255」
の画像データである。今回は学習時間を短縮するために
(訓練データ,テストデータ)=(500,100),(5000,1000)で学習させた。
データの分割
MNISTより取得した画像データを訓練データとテストデータに分割する。割合は8:2とした。
from sklearn.model_selection import train_test_split
import numpy as np
train_size = 5000 #訓練データ
test_size = 1000 #テストデータ
X_train, X_test, y_train, y_test = \
train_test_split(mnist.data, mnist.target, stratify=mnist.target,random_state=66,train_size=train_size,test_size=test_size)
-
train_test_split
はnp配列をランダムに2分割する。-
mnist.data
はデータ数×特徴量の2次元配列である。 -
mnist.target
はmnist.dataに対応するラベルを表す。(mnist.data[0]のラベルはmnist.target.[0]) -
stratify
は層化抽出を行う(同じ割合データを分割)。 -
random_state
は乱数シードを固定できる。固定することで常に同じように分割することが出来る。毎回ランダムに分割されると学習モデル同士の比較が正しく行えなくなる。
-
参考:https://note.nkmk.me/python-sklearn-train-test-split/
学習
from sklearn.neighbors import KNeighborsClassifier
import matplotlib.pyplot as plt
training_accuracy = []
test_accuracy = []
n_neighbors_settings = range(1,11)
for n_neighbors in n_neighbors_settings:
clf = KNeighborsClassifier(n_neighbors=n_neighbors)
clf.fit(X_train,y_train)
training_accuracy.append(clf.score(X_train, y_train))
test_accuracy.append(clf.score(X_test,y_test))
skelearn
からKNeighborsClassfier
をインポートする。今回はk=1,2,...,10としてk近傍法アルゴリズムを適用する。training_accuracy
は各kに対する訓練データの正解率の配列を表し、test_accuracy
は各kに対するテストデータの正解率の配列表す。学習結果をその配列に加えていく。
実行結果
-
train_size=500,test_size=100の場合
スイートスポットはk=7のときである。また、全体的に精度が悪い。
次にtrain_size=5000,test_size=1000として学習させる。 -
train_size=5000,test_size=1000の場合
スイートスポットはk=5のときである。また、training_size=500,test_size=100のときと比べて精度が全体的に良い。k=5のときテストデータに対する精度は95%であるから十分な精度といえるだろう。
予測する様子
- train_size=5000,test_size=1000のデータセットで学習させたk=5のとしたk-nnである。
#n_neighbor=5としたknnの学習モデル
plt.imshow(X_test[0].reshape(28,28),cmap="gray")
print(f"実際の値:{y_test[0]}")
print(f"予測した値:{clf2.predict(X_test)[0]}")
以上のようにコードを書いて予測しているか見てみる。
おまけ
- matplotlibでのグラフの描画の方法
plt.tick_params(colors="white")
plt.plot(n_neighbors_settings,training_accuracy,label="training_accuracy")
plt.xticks(np.arange(0,11,step=1))
plt.plot(n_neighbors_settings,test_accuracy,label="test_accuracy")
plt.ylabel("accuracy",color="white")
plt.xlabel("n_neighbors",color="white")
plt.legend()
plt.grid()
plt.xticks
はx軸の範囲を指定し、ステップ数を決める
plt.ylabel
はy軸の名前を決め、文字の色を指定
plt.xlabel
はx軸の名前を決め、文字の色を指定
plt.legend
は各グラフの説明を追加
plt.grid
はグリッドを追加(目盛り)