LoginSignup
0
1

More than 3 years have passed since last update.

K近傍法(NearestNeighbors)のPython実装メモ

Last updated at Posted at 2020-08-21

0 はじめに

本記事はsklearnでK近傍法を実装する際の自分用メモとして記載した内容です。他の記事を読んで理解した点や、記憶しておきたい点を図を用いてまとめています。

1 準備

以下の通り、numpyとsklearnからK近傍法で使うNearestNeighborsを読み込みます。

image.png

今回データセットとして、y=xの1次元関数を定義しデータセットを作成します。なお、今回用意をするデータイメージは以下の感じです。

image.png

訓練データセットは9個ずつで構成し[1,9,2]の次元で与えます。また10番目のデータを試験データとして与えます。今回確認する点として、訓練データセット9個に対し、試験データはどの点に最も近いのかを予測するさせます。1次関数でデータを作っているので結果は自明で、訓練データセットの9番目が最も近くなりますが、このシンプルな問題に対し、NearestNeighborsの動作を確認します。

image.png

実行結果は以下の通りです。

image.png
image.png

2 NearestNeighbors用のデータ定義

knn_vector_nは訓練データセットを格納する箱で、knn_vector_n1が試験データセットを格納する箱です。

image.png

for i in tqdm(range(0,len(test_data)-9)):
これはデータセットを作成するためのforループです。

knn_vector_n = test_data.iloc[i:i+9,0:2].values.reshape(-1,2)
これは1セット9個の訓練データを作成するためのスライスです。xとyのデータセットをpandasで与えたので、valueで取り出し、reshape(-1,2)で9行2列のnumpy配列に整形しています。

knn_vector_n1 = test_data.iloc[i+9,0:2].values.reshape(-1,2)
これは試験データを抽出するためのコードで、i+9で10番目のデータを抽出しています。

KNN=NearestNeighbors(n_neighbors=1,lgorithm='ball_tree').fit(knn_vector_n)
この行で各訓練データをNearestNeighborsに読み込ませています。n_neighbors=1は試験データに対し最も近い1つの試験データを探すというものです。実際に試験データに対しどの訓練データが近いのかを探す処理は以下で実施します。
distance, indices = KNN.kneighbors(knn_vector_n1)
第一引数が距離(ユークリッド距離)で、第二引数が何番目の訓練データかを返します。

dist_vec.append(distance)
indi_vec.append(indices)
この2行は、各試験データから計算された最近傍点の距離とデータの場所をforループごとに追加格納している内容です。

3 結果表示

print(dist_vec)

まず、dist_vecですがこれは以下三平方の定理で計算ができる通りの結果が表示されています。
image.png

image.png

もし
'KNN=NearestNeighbors(n_neighbors=2,algorithm='ball_tree').fit(knn_vector_n)`
の'n_neighbor=2'として宣言をした場合、テストデータから最も近い2点を返します。

image.png

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1