#0 はじめに
本記事はsklearnでK近傍法を実装する際の自分用メモとして記載した内容です。他の記事を読んで理解した点や、記憶しておきたい点を図を用いてまとめています。
#1 準備
以下の通り、numpyとsklearnからK近傍法で使うNearestNeighbors
を読み込みます。
今回データセットとして、y=x
の1次元関数を定義しデータセットを作成します。なお、今回用意をするデータイメージは以下の感じです。
訓練データセットは9個ずつで構成し[1,9,2]の次元で与えます。また10番目のデータを試験データとして与えます。今回確認する点として、訓練データセット9個に対し、試験データはどの点に最も近いのかを予測するさせます。1次関数でデータを作っているので結果は自明で、訓練データセットの9番目が最も近くなりますが、このシンプルな問題に対し、NearestNeighbors
の動作を確認します。
実行結果は以下の通りです。
#2 NearestNeighbors用のデータ定義
knn_vector_n
は訓練データセットを格納する箱で、knn_vector_n1
が試験データセットを格納する箱です。
for i in tqdm(range(0,len(test_data)-9)):
これはデータセットを作成するためのforループです。
knn_vector_n = test_data.iloc[i:i+9,0:2].values.reshape(-1,2)
これは1セット9個の訓練データを作成するためのスライスです。xとyのデータセットをpandasで与えたので、value
で取り出し、reshape(-1,2)
で9行2列のnumpy配列に整形しています。
knn_vector_n1 = test_data.iloc[i+9,0:2].values.reshape(-1,2)
これは試験データを抽出するためのコードで、i+9
で10番目のデータを抽出しています。
KNN=NearestNeighbors(n_neighbors=1,lgorithm='ball_tree').fit(knn_vector_n)
この行で各訓練データをNearestNeighbors
に読み込ませています。n_neighbors=1
は試験データに対し最も近い1つの試験データを探すというものです。実際に試験データに対しどの訓練データが近いのかを探す処理は以下で実施します。
distance, indices = KNN.kneighbors(knn_vector_n1)
第一引数が距離(ユークリッド距離)で、第二引数が何番目の訓練データかを返します。
dist_vec.append(distance)
indi_vec.append(indices)
この2行は、各試験データから計算された最近傍点の距離とデータの場所をforループごとに追加格納している内容です。
#3 結果表示
print(dist_vec)
まず、dist_vec
ですがこれは以下三平方の定理で計算ができる通りの結果が表示されています。
もし
'KNN=NearestNeighbors(n_neighbors=2,algorithm='ball_tree').fit(knn_vector_n)`
の'n_neighbor=2'として宣言をした場合、テストデータから最も近い2点を返します。