3
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

k-NNで手書き文字を認識するために、kの値を交差確認法で求めた。

Posted at

ソースコードをGitHubにあげています。

手書き文字を最近傍識別器によって識別したいのですが、kの値をいくつにすればいいのだろう?
そんな時は、交差確認法で最適なkを求めましょう。

用意した手書き文字は、16×16の大きさで、訓練データが5000、テストデータが2000あります。
交差確認法では、訓練データを5分割し、k=i(1<=i<=5)の時の誤識別数を求め、一番良かったkでテストデータを識別します。

以下、コードを順を追ってざっくり説明します。

訓練データ、テストデータを読み込む。

import csv
import numpy as np

train_data = []
test_data = []

for i in range(10):
    train_file = open("digit/digit_train%d.csv" % i)
    train_f = csv.reader(train_file)

    for row in train_f:
        train_data.append([i, [float(s) for s in row]])

    test_file = open("digit/digit_test%d.csv" % i)
    test_f = csv.reader(test_file)

    for row in test_f:
        test_data.append([i, [float(s) for s in row]])

ラベルを予測する関数を定義する

k近傍内に最大予測のラベルが2つ以上あったら、本当は一番近くにあったラベルを採用したいのですが、現状のコードでは、sortのせいで大きな方のラベルが採用されています(ランダムに近い)。要改善。。。

import collections

def get_nearest_label(data, train_data, k):
    distance_list = []
    for i in range(len(train_data)):
        distance = 0
        for j in range(len(train_data[i][1])):
            distance += (data[1][j] - train_data[i][1][j]) ** 2

        distance_list.append(distance)

    label_list = []
    min_index = 0
    for i in range(k):
        min_index = distance_list.index(min(distance_list))
        label_list.append(train_data[min_index][0])
        distance_list[min_index] += 10000

    counter = collections.Counter(label_list)
    sorted_counter = [(v, k) for k, v in counter.items()]
    sorted_counter.sort()
    estimated_label = sorted_counter[len(sorted_counter)-1][1]

    return estimated_label

kごとの誤識別数を出力する

for i in range(len(k_list)):
    k = k_list[i]
    num_error = 0

    for j in range(t):
        excluded_train_data = train_data[train_data_slice_range*j:train_data_slice_range*(j+1)]
        sliced_train_data = train_data[:train_data_slice_range*j] + train_data[train_data_slice_range*(j+1):]

        for l in range(len(excluded_train_data)):                
            label = get_nearest_label(excluded_train_data[l], sliced_train_data, k)

            if label != excluded_train_data[l][0]:
                num_error += 1

    print("識別誤り率 : %d" % (num_error*100/len(train_data)))

上記の結果は、
k=1 : num_error=143
k=2 : num_error=195
k=3 : num_error=164
k=4 : num_error=183
k=5 : num_error=177

よって、k=1を採用します。

テストデータの識別

estimate_matrix = np.zeros([10, 10])
    num_error = 0

    for i in range(len(test_data)):
        estimated_label = get_nearest_label(test_data[i], train_data, 1)
        if estimated_label != test_data[i][0]:
            num_error += 1

        estimate_matrix[test_data[i][0]][estimated_label] += 1

    print(estimate_matrix)
    print("誤識別数 : %d" % num_error)

出力は以下になります。

[[ 198.    0.    1.    1.    0.    0.    0.    0.    0.    0.]
 [   0.  200.    0.    0.    0.    0.    0.    0.    0.    0.]
 [   0.    0.  193.    1.    0.    0.    0.    1.    4.    1.]
 [   0.    0.    0.  195.    0.    3.    0.    0.    1.    1.]
 [   0.    0.    0.    0.  191.    1.    2.    0.    0.    6.]
 [   2.    0.    3.    4.    0.  187.    0.    1.    1.    2.]
 [   1.    0.    2.    0.    0.    2.  195.    0.    0.    0.]
 [   0.    0.    0.    1.    2.    0.    0.  192.    2.    3.]
 [   3.    0.    1.    4.    1.    3.    0.    0.  186.    2.]
 [   0.    0.    0.    0.    3.    0.    0.    1.    1.  195.]]
誤識別数 : 68

前回のFisherの線形判別法で認識するより、ずっと精度は上がりました!

3
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?