2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

手書き数字認識(k-nn)

Last updated at Posted at 2022-04-12

やったこと

今回、機械学習のアルゴリズムk近傍法(k-nn)を用いて手書き数字認識を行った。

データの読み込み

from sklearn.datasets import fetch_openml

mnist = fetch_openml("mnist_784")

sklearnのメソッドfetch_openmkでMNISTから手書き数字データmnist_784を取得する。
mnist_784は

  • 訓練データ6万枚、テストデータ1万枚の計7万枚の画像量
  • 28(縦)×28(幅)ピクセル
  • 8bitグレースケール=(白黒画像)白「0」~黒「255」

の画像データである。今回は学習時間を短縮するために
(訓練データ,テストデータ)=(500,100),(5000,1000)で学習させた。

データの分割

MNISTより取得した画像データを訓練データとテストデータに分割する。割合は8:2とした。

from sklearn.model_selection import train_test_split
import numpy as np
train_size = 5000 #訓練データ
test_size = 1000 #テストデータ

X_train, X_test, y_train, y_test = \
    train_test_split(mnist.data, mnist.target, stratify=mnist.target,random_state=66,train_size=train_size,test_size=test_size)
  • train_test_splitはnp配列をランダムに2分割する。
    • mnist.dataはデータ数×特徴量の2次元配列である。
    • mnist.targetはmnist.dataに対応するラベルを表す。(mnist.data[0]のラベルはmnist.target.[0])
    • stratifyは層化抽出を行う(同じ割合データを分割)。
    • random_stateは乱数シードを固定できる。固定することで常に同じように分割することが出来る。毎回ランダムに分割されると学習モデル同士の比較が正しく行えなくなる。

参考:https://note.nkmk.me/python-sklearn-train-test-split/

学習

from sklearn.neighbors import KNeighborsClassifier
import matplotlib.pyplot as plt

training_accuracy = []
test_accuracy = []

n_neighbors_settings = range(1,11)

for n_neighbors in n_neighbors_settings:
    clf = KNeighborsClassifier(n_neighbors=n_neighbors)
    clf.fit(X_train,y_train)

    training_accuracy.append(clf.score(X_train, y_train))
    test_accuracy.append(clf.score(X_test,y_test))

skelearnからKNeighborsClassfierをインポートする。今回はk=1,2,...,10としてk近傍法アルゴリズムを適用する。training_accuracyは各kに対する訓練データの正解率の配列を表し、test_accuracyは各kに対するテストデータの正解率の配列表す。学習結果をその配列に加えていく。

実行結果

  • train_size=500,test_size=100の場合
    image.png
    スイートスポットはk=7のときである。また、全体的に精度が悪い。
    次にtrain_size=5000,test_size=1000として学習させる。

  • train_size=5000,test_size=1000の場合
    image.png
    スイートスポットはk=5のときである。また、training_size=500,test_size=100のときと比べて精度が全体的に良い。k=5のときテストデータに対する精度は95%であるから十分な精度といえるだろう。

予測する様子

  • train_size=5000,test_size=1000のデータセットで学習させたk=5のとしたk-nnである。
#n_neighbor=5としたknnの学習モデル
plt.imshow(X_test[0].reshape(28,28),cmap="gray")
print(f"実際の値:{y_test[0]}")
print(f"予測した値:{clf2.predict(X_test)[0]}")

以上のようにコードを書いて予測しているか見てみる。

image.png
結果は以上のようになり正しく予測できているといえる。

おまけ

  • matplotlibでのグラフの描画の方法
plt.tick_params(colors="white")
plt.plot(n_neighbors_settings,training_accuracy,label="training_accuracy")
plt.xticks(np.arange(0,11,step=1))
plt.plot(n_neighbors_settings,test_accuracy,label="test_accuracy")
plt.ylabel("accuracy",color="white")
plt.xlabel("n_neighbors",color="white")
plt.legend()
plt.grid()

plt.xticksはx軸の範囲を指定し、ステップ数を決める
plt.ylabelはy軸の名前を決め、文字の色を指定
plt.xlabelはx軸の名前を決め、文字の色を指定
plt.legendは各グラフの説明を追加
plt.gridはグリッドを追加(目盛り)

2
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?