21
29

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

k近傍法(k-NN)の簡単なPython実装

Last updated at Posted at 2019-11-04

k近傍法の概要

k近傍法(k-NearistNeighbor)は機械学習のアルゴリズムで教師あり学習で、分類問題などに使います。とても似た名前にk平均法(k-means)があるりますが、k-meansは教師なし学習でクラスタリングに使います。
k近傍法のアルゴリズム自体はとても単純です。分類したいデータと、既存のデータとの距離を計算し、距離が近いk点のデータの多数決でクラスを決定します。
例えばk=1のときは距離が一番近いデータの仲間である、とするだけです。
図を見たほうがわかりやすいですね。
220px-KnnClassification.svg.png

k近傍法の例。標本(緑の丸)は、第一のクラス(青の四角)と第二のクラス(赤の三角)のいずれかに分類される。k = 3 なら、内側の円内にあるオブジェクトが近傍となるので、第二のクラスに分類される(赤の三角の方が多い)。しかし、k = 5 なら、それが逆転する。
引用:wikipedia

以上からk近傍法の特徴として以下が挙げられます。

  • 最も単純な教師なし学習である
  • 事前に識別のための学習を必要としない
  • 判別時に全データとの距離を計算するため、判別に時間がかかる
  • 次元が大きくなると、計算コスト的にも精度的にも問題がある(いわゆる次元の呪いにハマりやすい性質がある)

k近傍法の実装(Python)

簡単な実装

import numpy as np
import scipy.spatial.distance as distance
import scipy.stats as stats


import numpy as np
import scipy.spatial.distance as distance
import scipy.stats as stats


class knn:
    def __init__(self,k):
        self.k = k
        self._fit_X = None  # 既存データを格納 
        self.classes = None  #
        self._y = None
    def fit(self, X, label):
        # Xは元のデータ点で、shape(data_num, feature_num)
        print("original data:\n", X)
        print("label:\n", label)
        self._fit_X = X
        # ラベルデータからクラスを抽出、またラベルをindexとした配列を作成
        # self.classes[self.label_indices] == label のように復元できるのでreturn_inverseという
        self.classes, self.label_indices = np.unique(label, return_inverse=True)
        print("classes:\n", self.classes)
        print("label_indices:\n", self.label_indices)
        print("classes[label_indices]で復元されるか確認:\n", self.classes[self.label_indices])
    def neighbors(self, Y):
        # Yは予測対象のデータ点(複数可)で, shape(test_num, feature_num) 
        # 予測対象の点とデータ点の距離を求めるので、test_num * data_num だけ距離を計算する
        dist = distance.cdist(Y, self._fit_X)
        print("テストデータと元データとの距離:\n", dist)
        # distはshape(test_num, data_num) となる
        # [[1.41421356 1.11803399 2.6925824  2.23606798]   テスト点1と元データ各点との距離
        #  [3.         2.6925824  1.80277564 1.41421356]   テスト点2と元データ各点との距離
        #  [3.31662479 3.20156212 1.11803399 1.41421356]]  テスト点3と元データ各点との距離
        
        # 距離を測定したらk番目までに含まれるindexをもとめる
        # argpartitionはk番目までと、それ以降にデータを分ける関数
        # argsortだと距離の順位もわかるが、素のk-nnでは距離順位の情報はいらないので、argpartitionを使う
        neigh_ind = np.argpartition(dist, self.k)
        # neigh_indのshapeは(test_num, feature_num)となる
        # 上のdistでk=2でargpartitionしたときの結果
        # 例えば1行目だと index 2,1 が上位2要素になっている。上の距離をみると、0.5と1.5が相当する
        # 2行目だと index 3, 2 が上位2要素で、1.73と1.80が相当する
        #[[1 0 3 2]
        # [3 2 1 0]
        # [2 3 1 0]]
        # k番目までの情報だけを取り出す
        neigh_ind = neigh_ind[:, :self.k]
        # neigh_indのshapeは(test_num, self.k)となる
        #[[1 0]   テスト点1に近い元データ点のindexのリスト
        # [3 2]   テスト点2に近い元データ点のindexのリスト
        # [2 3]]  テスト点3に近い元データ点のindexのリスト
        return neigh_ind
    def predict(self, Y):
        # k番目までのindexを求める shape(test_num, self.k)となる
        print("test data:\n",Y)
        neigh_ind = self.neighbors(Y)
        # stats.modeでその最頻値を求める. shape(test_num, 1) . _は最頻値のカウント数
        # self.label_indices は [0 0 1 1] で、元データの各点のラベルを表す
        # neigh_indは各テスト点に近い元データのindexのリストで shape(est_num, k)となる
        # self.label_indices[neigh_ind] で、以下のような各テスト点に近いラベルのリストを取得できる
        # [[0 0]  テスト点1に近い元データ点のラベルのリスト
        #  [1 1]  テスト点2に近い元データ点のラベルのリスト
        #  [1 1]] テスト点3に近い元データ点のラベルのリスト
        # 上記データの行方向(axis=1)に対してmode(最頻値)をとり、各テスト点が属するラベルとする
        # _はカウント数
        mode, _ = stats.mode(self.label_indices[neigh_ind], axis=1)
        # modeはaxis=1で集計しているのでshape(test_num, 1)となるので、ravel(=flatten)してやる
        # [[0]
        #  [1]
        #  [1]]
        # なおnp.intpはindexに使うデータ型
        mode = np.asarray(mode.ravel(), dtype=np.intp)
        print("test dataの各ラベルindexの最頻値:\n",mode)
        # index表記からラベル名表記にする. self.classes[mode] と同じ
        result = self.classes.take(mode)
        return result

予測してみる

K = knn(k=2)
# 元のデータとラベルをセット
samples = [[0., 0., 0.], [0., .5, 0.], [1., 2., -2.5],[1., 2., -2.]]
label = ['a','a','b', 'b']
K.fit(samples, label)
# 予測したいデータ
Y = [[1., 1., 0.],[2, 2, -1],[1, 1, -3]]
p = K.predict(Y)
print("result:\n", p)

実行結果

>>result
original data:
 [[0.0, 0.0, 0.0], [0.0, 0.5, 0.0], [1.0, 2.0, -2.5], [1.0, 2.0, -2.0]]
label:
 ['a', 'a', 'b', 'b']
classes:
 ['a' 'b']
label_indices:
 [0 0 1 1]
classes[label_indices]で復元されるか確認:
 ['a' 'a' 'b' 'b']
test data:
 [[1.0, 1.0, 0.0], [2, 2, -1], [1, 1, -3]]
テストデータと元データとの距離:
 [[1.41421356 1.11803399 2.6925824  2.23606798]
 [3.         2.6925824  1.80277564 1.41421356]
 [3.31662479 3.20156212 1.11803399 1.41421356]]
test dataの各ラベルindexの最頻値:
 [0 1 1]
result:
 ['a' 'b' 'b']

参考文献

21
29
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
21
29

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?