LoginSignup
4
1

More than 1 year has passed since last update.

【k-近傍法】scikit.learnのBallTreeを使ってみた!

Last updated at Posted at 2022-06-10

【k-近傍法】scikit.learnのBallTreeを使ってみた!

<目次>
1.はじめに
2.k-近傍法とは
3.BallTreeの特徴
4.実装
5.参考にしたサイト・書籍

1.はじめに

kaggleの「Foursquare - Location Matching」という大会に取り組んでいく中で,k-近傍法のパッケージであるscikit.learnのBallTreeを使う機会があったため,学んだことや実装する場面・方法をご紹介していきたいと思います!

2.k-近傍法(kNN法)とは

対象とするデータに最も近いデータをk個取ってきて,それらが最も多く所属するクラスに識別する方法.

kNN法.png
                    図1

図1の場合だと,k=3の場合は対象のデータはクラス1に分類され,k=5の場合はクラス2に分類される.

※名前が似ているK-平均法(K-means)とは別物なので注意.
(K-平均法は,データの平均値を使ってK個のクラスタを作るクラスタリングの手法.)

3.BallTreeの特徴

BallTreeとはあるデータ点の近傍を探すのに適したデータ構造のこと

以下ではBallTreeの簡単な概観を図を用いて説明する.

図2ように平面上にデータが与えられた時を考える.
スクリーンショット 2022-06-10 15.31.20.png
                    図2

まずデータをaとbという円領域に分割する.(分割方法については論文1を参照)
スクリーンショット 2022-06-10 15.31.32.png
                    図3

それを各円領域a,bでも繰り返していく.
スクリーンショット 2022-06-10 15.31.42.png
                    図4

分割し終わったものが図5のようになる.
スクリーンショット 2022-06-10 15.31.58.png
                    図5

そして,図5からは図6のような木が得られることとなる.
スクリーンショット 2022-06-10 15.32.12.png
                    図6
このように円領域(Ball)をもとに木(Tree)を作るようなデータ構造であるため,BallTreeという名称になっている.
こうして得られた木と対象のデータqと円領域との距離を用いて,効率良く最近傍探索(最も近いデータを見つけること)や範囲探索(ある範囲内のデータを探すこと)を行うことができる.

4.実装

目的:以下のテストデータから,同じ場所を指しているid同士をmatchさせること
スクリーンショット 2022-06-10 17.02.37.png

指針:BallTreeを用いたk-近傍法により各idと最も近い2点を出力し,その距離が基準以下であればidを出力する.(ここでは「name」「categories」「latitude」「longitude」という4つの特徴量を用いる)

import numpy as np
import pandas as pd 
from sklearn.neighbors import BallTree
from tqdm import tqdm
tree=BallTree(np.deg2rad(test_0[["latitude","longitude"]].values),metric='haversine')

#test_0はテストデータを指す
#np.deg2radで緯度と経度をまとめてラジアンに直している
#'haversine'は球面上の距離を求める方法
#結果を格納するリスト
pois_out=[] 

#近傍数
n=min(20,len(test_0)) 

#マッチさせるidの最大数
max_poi=2

#最大距離
max_dist_cat = 0.00018
max_dist_name = 0.0018
max_dist = max(max_dist_cat, max_dist_name)

for i, row in tqdm(test_0.iterrows()):
    #2点の最近傍データをdist,indに出力
    dist, ind = tree.query(np.deg2rad(np.c_[row['latitude'], row['longitude']]), k = n)
    poi = []
    for d, j in zip(dist[0], ind[0]):
        if d <= max_dist_cat and row['categories'] != '__NAN__' and (row['categories'] in test_0.categories.iloc[j] or test_0.categories.iloc[j] in row['categories']):
            poi.append(test_0.id.iloc[j])
        elif d <= max_dist_name and row['name'] != '__NAN__' and (row['name'].lower() == test_0.name.iloc[j].lower()):
            poi.append(test_0.id.iloc[j])
        if d > max_dist or len(poi) >= max_poi:
            break

    if len(poi) == 0:
        pois_out.append(row['id'])
    else:
        pois_out.append(' '.join(poi))

得られたデータをtestデータに格納すると以下のようになった.
スクリーンショット 2022-06-10 17.02.44.png

与えられていたサンプルのデータと同じ結果になってしまいました...

今後はこれをどのように応用するかを考えていきたいと思います!

5.参考文献

4
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
1