3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

pythonによるk-近傍法(kNN)による分類(【高等学校情報科 情報Ⅱ】教員研修用教材)

Posted at

はじめに

k-近傍法(k-nearest neighbor method,kNN)とは、予測したい値に最も距離が近いk個の中で多数決を取り、予測値を決めるシンプルな機械学習アルゴリズムです。
以下の説明が直感的でわかりやすいです。

右の図では,スマイルマークが予測したい値の位置で,k=3としたときの近傍の範囲を表している。この場合は,◆がスマイルマークの予測値となる。
SnapCrab_NoName_2020-9-1_18-58-0_No-00.png

この記事では教材にてk-近傍法による分類について、Rにより実装されている部分をpythonに書き換えていきたいと思います。

教材

高等学校情報科「情報Ⅱ」教員研修用教材(本編):文部科学省
第3章 情報とデータサイエンス 後半 (PDF:7.6MB)

#環境

教材内で取り上げる箇所

学習15 分類による予測:「3.k-近傍法による分類」

今回取り扱うデータ

教材と同じように、kaggleからdigit-recognizerのデータをダウンロードします。
使用するのは「train.csv」です。

pythonでの実装例と結果

訓練データと試験データの読み込み

train.csvには、42,000個の手書き数字の情報が格納されており、1つの手書き数字の情報は1列目(label)に正解ラベル(正解の数字)、2列目以降(pixel)に784(28×28)個のピクセルの256段階のグレースケールの階調値(0-255)がされているような形となっています。

ここでは、訓練データとして先頭から1,000個、テストデータとして次の100個のデータを使用します。

import numpy as np
import pandas as pd
from IPython.display import display

mnist = pd.read_csv('/content/train.csv')

mnist_light = mnist.iloc[:1000,:]
mnist_light_test = mnist.iloc[1000:1100,:]

# 訓練データ
Y_mnist_light = mnist_light[['label']].values.ravel()
#display(Y_mnist_light)
X_mnist_light = mnist_light.drop('label', axis = 1)
#display(X_mnist_light)

# テストデータ
Y_mnist_light_test = mnist_light_test[['label']].values.ravel()
#display(Y_mnist_light_test)
X_mnist_light_test = mnist_light_test.drop('label', axis = 1)
#display(X_mnist_light_test)

訓練データの学習と予測

k=3としたときのk-近傍法と訓練データにより訓練させたのち、テストデータ100個から予測値を取得します。
テストデータのlabel(正解値)と比較し、正答率を表示します。

from sklearn.neighbors import KNeighborsClassifier
from sklearn import metrics
from sklearn.metrics import accuracy_score

# sklearn.neighbors.KNeighborsClassifier使用
knn = KNeighborsClassifier(n_neighbors = 3)
knn.fit(X_mnist_light, Y_mnist_light)

# 予測実行
pred_y = knn.predict(X_mnist_light_test)
display(pred_y)

#正答確認
result_compare = pred_y == Y_mnist_light_test
display(result_compare)

minist_accuracy_score = accuracy_score(Y_mnist_light_test, pred_y)

#正答率
print(minist_accuracy_score)

実行結果は以下になります。

array([1, 5, 1, 7, 9, 8, 9, 5, 7, 4, 7, 2, 8, 1, 4, 3, 8, 6, 2, 7, 2, 6,
       7, 8, 1, 8, 8, 1, 9, 0, 9, 4, 6, 6, 8, 2, 3, 5, 4, 5, 4, 1, 3, 7,
       1, 5, 0, 0, 9, 5, 5, 7, 6, 8, 2, 8, 4, 2, 3, 6, 2, 8, 0, 2, 4, 7,
       3, 4, 4, 5, 4, 3, 3, 1, 5, 1, 0, 2, 2, 2, 9, 5, 1, 6, 6, 9, 4, 1,
       7, 2, 2, 0, 7, 0, 6, 8, 0, 5, 7, 4])
array([ True,  True,  True,  True, False,  True,  True,  True,  True,
        True,  True,  True, False,  True,  True, False,  True,  True,
        True, False,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True, False, False,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True, False,  True, False,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True, False,
        True,  True,  True,  True,  True,  True,  True,  True, False,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True, False,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True])
0.89

正答率は0.89となっています。

実際に誤った手書き数字の確認

テストデータのうち、5つ目のテストデータが誤って認識されています。
実際に誤った数字の確認を行っています。

import matplotlib.pyplot as plt

# テストデータの表示
fig, axes = plt.subplots(2, 5)
fig.subplots_adjust(left=0, right=1, bottom=0, top=1.0, hspace=0.1, wspace=0.1)
for i in range(2):
    for j in range(5):
        axes[i, j].imshow(X_mnist_light_test.values[i*5+j].reshape((28, 28)), cmap='gray')
        axes[i, j].set_xticks([])
        axes[i, j].set_yticks([])
plt.show()

実行結果は以下になります。

SnapCrab_NoName_2020-9-1_19-16-22_No-00.png

上の段の一番右の手書き数字は、label(正解値)は4でしたが予測値は9として誤認識したようです。目視でも9と判定しかねない数字にみえます。

混合行列とkの値を変えたときの正答率の変化

縦軸が予測値、横軸が正解ラベルとして、その個数を示す表を表示します。

from sklearn.metrics import confusion_matrix

cfm = confusion_matrix(Y_mnist_light_test, pred_y)

print(cfm)

実行結果は以下になります。

[[ 7  0  0  0  0  0  0  0  0  0]
 [ 0 10  0  0  0  0  0  0  0  0]
 [ 0  0 13  0  0  0  0  1  1  0]
 [ 0  0  0  5  0  1  0  0  0  0]
 [ 0  0  0  0 11  0  0  0  0  1]
 [ 0  0  0  0  0 10  0  0  0  0]
 [ 0  0  0  0  0  0  9  0  0  0]
 [ 1  0  0  0  0  0  0 10  0  1]
 [ 0  0  0  2  0  0  0  0 10  1]
 [ 0  1  0  0  1  0  0  0  0  4]]

次にkの値は何が適していたのかを見るためにkの値を変えたときに正答率をグラフで表示します。

n_neighbors_chg_list = []

# n_neighborsを変化させたときのグラフ
for i in range(1,100):
    # sklearn.neighbors.KNeighborsClassifier使用
    knn_temp = KNeighborsClassifier(n_neighbors = i)
    knn_temp.fit(X_mnist_light, Y_mnist_light)

    # 予測実行
    pred_y_temp = knn_temp.predict(X_mnist_light_test)

    #正解率
    minist_accuracy_score_temp = accuracy_score(Y_mnist_light_test, pred_y_temp)

    #配列に格納
    n_neighbors_chg_list.append(minist_accuracy_score_temp)

plt.plot(n_neighbors_chg_list)

実行結果は以下になります。

ダウンロード (12).png

一般的に、kの値が大きければ外れ値があっても結果に与える影響は少ないためノイズの影響を低減できますが、クラスの境界が明確にならない傾向にあります。
訓練データの数などで適切なkの値は変わってくるが、今回の試行ではkが大きくなるほど正答率は低くなる傾向が見えました。

#ソースコード

3
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?