LoginSignup
6
6

More than 5 years have passed since last update.

球面上の近傍探索 Neighbor search on sphere

Last updated at Posted at 2018-01-13

モチベーション

  • 球面上にランダムに分布しているサンプル点(n ~ 10,000)のうち、ある地点の周辺にある点をすべて列挙したい
  • クエリの数が多い(q ~ 100,000)ので高速化したい

モチベーション

記事概要

  • 高速化をいろいろ試みた
  • 松・竹・梅の3つのアルゴリズムを比較
  • 決定版は松コースを参照のこと

梅コース: 線形探索

  • 愚直に総当たりで二点間距離が条件を満たすものを探す
  • 遅いが実装が簡単
  • データ点 n = 10000, クエリ数 q = 1000に対して19.1秒
  • 計算コストはデータ点数に対して線形に増加する O(nq)
def linear_search(points, queries):
    num_total = 0
    for q in queries:
        lat_query, lon_query, max_dist = q
        indices = []
        for i, p in enumerate(points):
            if dist(p[0], p[1], lat_query, lon_query) <= max_dist:
                indices.append(i)
        num_total += len(indices)
    print(num_total)

竹コース: 平方分割

  • データ点 n を、√n 個程度のboxに分割して探索する
  • binに対する探索、その内部での探索がそれぞれ O(√n) で行えるため、総計算量は O(q√n) となり線形探索より速い
  • データ点 n = 10000, クエリ数 q = 1000に対して4.9秒
  • この例では経緯度をそれぞれ12分割したが、あまり高速化できなかった。nがより大きい場合、線形探索に対する優位はより大きくなる(はず)
from math import *
from itertools import chain

class Box():
    def __init__(self, clat, clon, radius):
        self.center = [clat, clon]
        self.radius = radius
        self.points = []

    def add_point(self, point):
        self.radius = max(self.radius, dist(self.center[0], self.center[1], point[0], point[1]))
        self.points.append(point)

def square_root_decomposition(points, queries):
    DECOMP_LAT = 12
    DECOMP_LON = 12
    boxes = [[None for j in range(DECOMP_LON)] for i in range(DECOMP_LAT)]

    for i in range(DECOMP_LAT):
        for j in range(DECOMP_LON):
            clat = ((i + 0.5) / DECOMP_LAT) * pi - pi / 2
            clon = ((j + 0.5) / DECOMP_LON) * 2 * pi - pi
            radius = 0.0
            boxes[i][j] = Box(clat, clon, radius)

    for p in points:
        i = lat_to_i(p[0], DECOMP_LAT)
        j = lon_to_j(p[1], DECOMP_LON)
        boxes[i][j].add_point((p[0], p[1]))

    num_total = 0
    for q in queries:
        lat_query, lon_query, max_dist = q
        res = []
        for box in chain.from_iterable(boxes):
            flag = bquery(box, [lat_query, lon_query], max_dist)
            if flag == 2:
                res += box.points
            elif flag == 1:
                for p in box.points:
                    if dist(lat_query, lon_query, p[0], p[1]) <= max_dist:
                        res.append(p)
        num_total += len(res)
    print(num_total)

def bquery(box, pq, pd):
    cdist = dist(box.center[0], box.center[1], pq[0], pq[1])
    if cdist + box.radius <= pd:
        return 2
    elif cdist - box.radius <= pd:
        return 1
    else:
        return 0

def lat_to_i(lat, mdiv):
    i = int(mdiv * (lat + pi / 2) / pi)
    if i == mdiv:
        i = mdiv - 1
    assert 0 <= i < mdiv
    return i

def lon_to_j(lon, mdiv):
    j = int(mdiv * (lon + pi) / (2 * pi))
    assert 0 <= j < mdiv
    return j

球面上に拡張する上でのヒント

  • 3点 A, B, Cに対して、球面上の距離は三角不等式 AC <= AB + BC を常に満たす
  • 各箱との交差判定は、クエリ座標、クエリ半径、箱の中心座標、箱の半径(箱の中で箱の中心から最も遠い点までの距離)が使える
    • (箱中心とクエリ座標との距離)≦(クエリ半径 - 箱の半径)なら箱全体がクエリ半径に含まれる
    • (箱中心とクエリ座標との距離)≦(クエリ半径 + 箱の半径)なら箱の一部がクエリ半径に含まれうる
    • それ以外の場合、箱はクエリ半径と交わらない

image.png

既知の問題と対処法

  • 極周辺のクエリでは、箱が南北に細長くなってしまうため、箱半径を用いた当たり判定は無駄が多くなる
    image.png

  • 単純な経緯度での分割をやめ、例えば次のような分割で、より箱の形を一様にすることで対処できる
    image.png
    (画像取得元 https://www.gfdl.noaa.gov/wp-content/uploads/2016/05/grid_1.jpg)

松コース: 3次元ユークリッド空間への埋め込みとkd木の使用

  • 球面上で考えると当たり判定が面倒だし、雑な扱いをすると極の周りでパフォーマンスが落ちる
  • 2次元球面(lat, lon)を3次元ユークリッド空間(x, y, z)に 埋め込め ば、諸々の問題が解決して実装が楽
    • 球面上の距離(大円距離・弧長)は3次元ユークリッド空間上の距離(弦長)と一対一に対応することを利用する
  • 既存のアルゴリズムやライブラリ(kd木八分木 、、、)が使える
    • ここでは、内部がcで書かれたkd木(scipy.spatial.cKDTree)を使った。
    • データ点 n = 10000, クエリ数 q = 1000に対して0.06秒。爆速。
  • Pure pythonで実装した場合(cKDTree→KDTree)は、竹コースと同等程度の速度
    • 埋め込みにより次元が高くなるので、パフォーマンス的にはマイナス(curse of dimensionality of kd tree)。極対策等をチューンした二次元平方分割/四分木/kd木には負けそう
    • 実際、「竹」と同じ経緯度分割、半径を使った当たり判定で四分木を書いたら3次元ユークリッド空間埋め込みkd木より3倍くらい速かった
from scipy import spatial

def latlon_to_xyz(lat, lon):
    x = cos(lat) * cos(lon)
    y = cos(lat) * sin(lon)
    z = sin(lat)
    return x, y, z

def dist_on_sphere_to_cartesian(dist_on_sphere):
    return 2.0 * sin(dist_on_sphere * 0.5 / R_SPHERE)

def use_ckd_tree(points, queries):
    points_xyz = []
    for p in points:
        x, y, z = latlon_to_xyz(*p)
        points_xyz.append((x, y, z))
    tree = spatial.cKDTree(points_xyz)

    num_total = 0
    for q in queries:
        lat_query, lon_query, max_dist = q
        x_query, y_query, z_query = latlon_to_xyz(lat_query, lon_query)
        indices = tree.query_ball_point(
            [x_query, y_query, z_query], dist_on_sphere_to_cartesian(max_dist))
        num_total += len(indices)
    print(num_total)

まとめ

  • Pythonユーザは3次元ユークリッド空間に埋め込んだうえでscipy.spatial.cKDTreeを使おう
  • 最高速度を狙うなら上手い2次元分割を考えてC/C++/Fortranで書こう(丸投げ)

追記

  • kkddさんからsklearn.neighbors.BallTreeを使えばいい、と情報提供をいただきましたので、試してみました。
    • データ点 n = 10000, クエリ数 q = 1000に対して0.16秒と、cKDTreeに及ばないまでも十分な高速化。
    • これはお手軽ですね。ありがとうございます。
import sklearn.neighbors

def use_ball_tree(points, queries):
    tree = sklearn.neighbors.BallTree(np.array(points), leaf_size=2, metric="haversine")
    num_total = 0
    for q in queries:
        lat_query, lon_query, max_dist = q
        ind = tree.query_radius([[lat_query, lon_query]], r=max_dist/R_SPHERE)
        num_total += len(ind[0])
    print(num_total)

コード全体

#!/usr/bin/env python3

from math import *
import numpy as np
from scipy import spatial
from itertools import chain

R_SPHERE = 6400000.0  # Earth ~6400km
NPOINTS = 10000
NQUERY = 1000

def rand_uniform_lat_lon():
    while 1:
        x, y, z = np.random.uniform(-1.0, 1.0, 3)
        r = (x ** 2 + y ** 2 + z ** 2) ** 0.5
        if 0.0 < r <= 1.0:
            break
    lat, lon = xyz_to_latlon(x / r, y / r, z / r)
    return lat, lon

def xyz_to_latlon(x, y, z):
    # -pi/2 <= lat <= pi/2
    # -pi   <  lon <= pi
    lat = asin(z)
    lon = atan2(y, x)
    return lat, lon

def latlon_to_xyz(lat, lon):
    x = cos(lat) * cos(lon)
    y = cos(lat) * sin(lon)
    z = sin(lat)
    return x, y, z

def dist(lat1, lon1, lat2, lon2):
    return R_SPHERE * 2 * asin(sqrt(sin(abs(lat1 - lat2) * 0.5) ** 2
        + cos(lat1) * cos(lat2) * sin(abs(lon1 - lon2) * 0.5) ** 2))

def dist_on_sphere_to_cartesian(dist_on_sphere):
    return 2.0 * sin(dist_on_sphere * 0.5 / R_SPHERE)

def linear_search(points, queries):
    num_total = 0
    for q in queries:
        lat_query, lon_query, max_dist = q
        indices = []
        for i, p in enumerate(points):
            if dist(p[0], p[1], lat_query, lon_query) <= max_dist:
                indices.append(i)
        num_total += len(indices)
    print(num_total)

def use_ckd_tree(points, queries):
    points_xyz = []
    for p in points:
        x, y, z = latlon_to_xyz(*p)
        points_xyz.append((x, y, z))
    tree = spatial.cKDTree(points_xyz)

    num_total = 0
    for q in queries:
        lat_query, lon_query, max_dist = q
        x_query, y_query, z_query = latlon_to_xyz(lat_query, lon_query)
        indices = tree.query_ball_point(
            [x_query, y_query, z_query], dist_on_sphere_to_cartesian(max_dist))
        num_total += len(indices)
    print(num_total)

class Box():
    def __init__(self, clat, clon, radius):
        self.center = [clat, clon]
        self.radius = radius
        self.points = []

    def add_point(self, point):
        self.radius = max(self.radius, dist(self.center[0], self.center[1], point[0], point[1]))
        self.points.append(point)

def square_root_decomposition(points, queries):
    DECOMP_LAT = 12
    DECOMP_LON = 12
    boxes = [[None for j in range(DECOMP_LON)] for i in range(DECOMP_LAT)]

    for i in range(DECOMP_LAT):
        for j in range(DECOMP_LON):
            clat = ((i + 0.5) / DECOMP_LAT) * pi - pi / 2
            clon = ((j + 0.5) / DECOMP_LON) * 2 * pi - pi
            radius = 0.0
            boxes[i][j] = Box(clat, clon, radius)

    for p in points:
        i = lat_to_i(p[0], DECOMP_LAT)
        j = lon_to_j(p[1], DECOMP_LON)
        boxes[i][j].add_point((p[0], p[1]))

    num_total = 0
    for q in queries:
        lat_query, lon_query, max_dist = q
        res = []
        for box in chain.from_iterable(boxes):
            flag = bquery(box, [lat_query, lon_query], max_dist)
            if flag == 2:
                res += box.points
            elif flag == 1:
                for p in box.points:
                    if dist(lat_query, lon_query, p[0], p[1]) <= max_dist:
                        res.append(p)
        num_total += len(res)
    print(num_total)

def bquery(box, pq, pd):
    cdist = dist(box.center[0], box.center[1], pq[0], pq[1])
    if cdist + box.radius <= pd:
        return 2
    elif cdist - box.radius <= pd:
        return 1
    else:
        return 0

def lat_to_i(lat, mdiv):
    i = int(mdiv * (lat + pi / 2) / pi)
    if i == mdiv:
        i = mdiv - 1
    assert 0 <= i < mdiv
    return i

def lon_to_j(lon, mdiv):
    j = int(mdiv * (lon + pi) / (2 * pi))
    assert 0 <= j < mdiv
    return j

def main():
    points = []
    for i in range(NPOINTS):
        lat, lon = rand_uniform_lat_lon()
        points.append((lat, lon))

    queries = []
    for j in range(NQUERY):
        lat_query, lon_query = rand_uniform_lat_lon()
        max_dist = 1000000  # 1000km
        queries.append((lat_query, lon_query, max_dist))

    linear_search(points, queries)
    use_ckd_tree(points, queries)
    square_root_decomposition(points, queries)

if __name__ == "__main__":
    main()

参考にしたもの

6
6
7

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
6