LoginSignup
1
0

More than 1 year has passed since last update.

Pythonで3次元kd木(kd-tree)探索

Last updated at Posted at 2022-08-22

Unityで点群データの効率的な探索をするために3次元のkd木探索法が使えるかなと思ったのですが、C#で書く前にグラフを出力して確認しながらアルゴリズムを理解した方が良いと思ったので一度Pythonで実装してみます。

....簡単に実装したいなら既にscipyなどのライブラリで提供されているのでそちらを使うと良いと思います。

参考

以下の@fj-thさんの記事を読んでいることを前提とします。
アルゴリズムの説明とC言語でkd-treeの実装をしています。

2次元のkd木

3次元のkd木は、2次元のkd木を拡張するれば良いので、まず@fj-thさんのC言語のkd-木(リポジトリ)をPythonに書き換えます。
ただPythonで書き換えただけなので説明は割愛します。
以下のノートブックを参照ください。

3次元のkd木

kd-treeクラス

2次元のkd木を拡張する形で実装します。

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
"""
特定列を基準にソート
axis_xy : xなら0 yなら1 zなら2
"""
def sortxyz(arr:np.ndarray,axis_xyz:int,offset:int = 0):
    axis = axis_xyz + offset # offset
    return arr[np.argsort(arr[:,axis])]

# test
arr = np.array([[1000000,2,50],
               [2000000,0,100],
               [3000000,1,200]])
print(sortxyz(arr,0))

x_col = 0
y_col = 1
z_col = 2
class Node:
    def set_node(self,points:np.ndarray,right:int,depth:int):
        if right < 0:
            return None
        elif right == 0:
            return self.set_leaf(points[right],depth)
        
        axis = depth % 3
        sorted_points = sortxyz(points[:right + 1],axis)
        if axis == 0:
            self.left_most = sorted_points[0][x_col]
            self.right_most = sorted_points[right][x_col]
        elif axis == 1:
            self.bottom_most = sorted_points[0][y_col]
            self.top_most = sorted_points[right][y_col]
        else:
            self.front_most = sorted_points[0][z_col]
            self.back_most = sorted_points[right][z_col]
            
        median = int(right / 2)
        self.border = (sorted_points[median][x_col:] + sorted_points[median+1][x_col:])/2.0
        self.location = sorted_points[median]
        self.depth = depth
        self.right_child = Node().set_node(sorted_points[median
                                                          + 1:],right -( median + 1),depth + 1)
        self.left_child = Node().set_node(sorted_points,median,depth + 1)
        
        #ここから先は関連付け
        if not(axis == 1):
            if(not(self.right_child is None) and not(self.left_child is None)):
                self.top_most = self.right_child.top_most if self.right_child.top_most > self.left_child.top_most else self.left_child.top_most
                self.bottom_most = self.right_child.bottom_most if self.right_child.bottom_most < self.left_child.bottom_most else self.left_child.bottom_most
            elif not(self.right_child is None):
                self.top_most = self.right_child.top_most
                self.bottom_most = self.right_child.bottom_most
            elif not(self.left_child is None):
                self.top_most = self.left_child.top_most
                self.bottom_most = self.left_child.bottom_most
            else:
                self.top_most = self.location[y_col]
                self.bottom_most = self.location[y_col]
        if not(axis == 0):
            if(not(self.right_child is None) and not(self.left_child is None)):
                self.right_most = self.right_child.right_most if self.right_child.right_most > self.left_child.right_most else self.left_child.right_most
                self.left_most = self.right_child.left_most if self.right_child.left_most < self.left_child.left_most else self.left_child.left_most
            elif not(self.right_child is None):
                self.right_most = self.right_child.right_most
                self.left_most = self.right_child.left_most
            elif not(self.left_child is None):
                self.right_most = self.left_child.right_most
                self.left_most = self.left_child.left_most
            else:
                self.right_most = self.location[x_col]
                self.left_most = self.location[x_col]
        if not(axis == 2):
            if(not(self.right_child is None) and not(self.left_child is None)):
                self.back_most = self.right_child.back_most if self.right_child.back_most > self.left_child.back_most else self.left_child.back_most
                self.front_most = self.right_child.front_most if self.right_child.front_most < self.left_child.front_most else self.left_child.front_most
            elif not(self.right_child is None):
                self.back_most = self.right_child.back_most
                self.front_most = self.right_child.front_most
            elif not(self.left_child is None):
                self.back_most = self.left_child.back_most
                self.front_most = self.left_child.front_most
            else:
                self.back_most = self.location[z_col]
                self.front_most = self.location[z_col]
                
        return self
    def set_leaf(self,location:np.ndarray,depth:int):
        self.location = location
        self.left_child = None
        self.right_child = None
        self.depth = depth
        self.left_most = location[x_col]
        self.right_most = location[x_col]
        self.top_most =   location[y_col]
        self.bottom_most = location[y_col]
        self.front_most = location[z_col]
        self.back_most = location[z_col]
        return self
    def is_contained(self,sx:int,tx:int,sy:int,ty:int,sz:int,tz:int):
        """
        > regionノードから始まるsubtreeの要素が指定領域にすっぽり収まっているか否かを返す。  
        > 各ノードに、自分以下の子の最大、最小値をもたせているので、それを参照するだけで判定できる。
        実装参考のis_contained関数より引用
        """
        return not(self.left_most < sx or self.right_most > tx or \
                    self.top_most > ty or self.bottom_most < sy or \
                    self.front_most < sz or self.back_most > tz)
class Tree:
    def build(self,points):
        node = Node().set_node(points,len(points) - 1,0)
        self.top_node = node
        return node
    def search(self,r:list):
        x,y,z = r
        sx,tx = x
        sy,ty = y
        sz,tz = z
        search_results = []
        def _search(v:Node):
            nonlocal search_results,sx,tx,sy,ty
            if v.right_most < sx or v.left_most > tx or \
                v.bottom_most > ty or v.top_most < sy or \
                v.front_most > tz or v.back_most < sz:
                return;
            if v.left_child is None and v.right_child is None:
                if sx <= v.location[x_col] and sy <= v.location[y_col] and  sz <= v.location[z_col] and\
                    tx >= v.location[x_col] and ty >= v.location[y_col] and  tz >= v.location[z_col]:
                    self.search_results.append(v.location)
                    return
            if not(v.left_child is None):
                if(v.left_child.is_contained(sx,tx,sy,ty,sz,tz)):
                    search_results += self.report_subtree(v.left_child)
                else:
                    _search(v.left_child)
            if not(v.right_child is None):
                if(v.right_child.is_contained(sx,tx,sy,ty,sz,tz)):
                    search_results += self.report_subtree(v.right_child)
                else:
                    _search(v.right_child)
        _search(self.top_node)
        return search_results
    def report_subtree(self,node:Node = None,parent:Node=None,parent_border_axis:tuple = None,draw_border:tuple = None):
        """
        draw_border:tuple = (ax,max_x,min_x,max_y,min_y,max_z,min_z)
        """
        if node is None:
            node = self.top_node
        if node.left_child is None and node.right_child is None:
            return [node.location]
        pba = parent_border_axis
        if not (draw_border is None) :
            def _plot(axis:int,s,e):
                if axis == 0:
                    x = [s[0],e[0]]
                    y = [s[1],e[1]]
                    z = np.linspace(s[2], e[2])
                    Y, Z = np.meshgrid(y, z)
                    X = np.array([x] * Y.shape[0])
                    ax.plot_surface(X, Y, Z) 
                elif axis == 1:
                    x = [s[0],e[0]]
                    y = [s[1],e[1]]
                    z = np.linspace(s[2], e[2])
                    X, Z = np.meshgrid(x, z)
                    Y = np.array([y] * X.shape[0])
                    ax.plot_surface(X, Y, Z) 
                else:
                    x = np.linspace(s[0], e[0])
                    y = [s[1],e[1]]
                    z = [s[2],e[2]]
                    Z, X = np.meshgrid(z,x)
                    Y = np.array([y] * Z.shape[0])
                    ax.plot_surface(X, Y, Z)
            ax,xmax,xmin,ymax,ymin,zmax,zmin = draw_border
            s = node.border.copy()
            e = [0,0,0]
            if parent is None:
                s[1] = ymax
                s[2] = zmax
                e = np.array([s[0],ymin,zmin])
                pba = (s[0],None,s[2])
                _plot(0,s,e)
            else:
                if not(pba[0] is None) and not(pba[2] is None): #前がxの区切り(yz線) => 今がyの区切り(xz線)
                    x = xmax if(pba[0] < s[0]) else xmin
                    z = zmax if(pba[2] < s[2]) else zmin
                    s[0] = pba[0]
                    s[2] = pba[2]
                    e = np.array([x,s[1],z])
                    pba = (s[0],s[1],None)
                    _plot(1,s,e)
                elif not(pba[1] is None) and not(pba[0] is None): #前がyの区切り(xz線) => 今がzの区切り(xy線)

                    x = xmax if(pba[0] < s[0]) else xmin
                    y = ymax if(pba[1] < s[1]) else ymin
                    s[1] = pba[1]
                    s[0] = pba[0]
                    e = np.array([x,y,s[2]])
                    pba = (None,s[1],s[2])
                    _plot(2,s,e)
                else: #前がzの区切り(xy線) => 今がxの区切り(yz線)
                    y = ymax if(pba[1] < s[1]) else ymin
                    z = zmax if(pba[2] < s[2]) else zmin
                    s[2] = pba[2]
                    s[0] = pba[0]
                    e = np.array([s[0],y,z])
                    pba = (s[0],None,s[2])
                    _plot(0,s,e)
            
        arr = []
        if not (node.left_child is None):
            arr += self.report_subtree(node.left_child,node,pba,draw_border)
        if not (node.right_child is None):
            arr += self.report_subtree(node.right_child,node,pba,draw_border)
        return arr

追加したNodeクラスのfront_mostとback_mostが各ノードのz軸の範囲を持ちます。

テストデータ作成とkd-tree構築

テストデータを使って境界を表示してみます。

# test
points = [[1,2,2],
          [2,4,4],
          [5,6,3],
          [7,8,6],
          [3,10,4],
          [11,11,3],
         [10,10,4],
         [0,8,7]]
points = np.array(points)
tree = Tree()
node = tree.build(points)
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(points[:,x_col],points[:,y_col],points[:,z_col])
plt.show()
%matplotlib notebook
fig = plt.figure()
ax = fig.gca(projection='3d')
ax.scatter(points[:,x_col],points[:,y_col],points[:,z_col])
tree.report_subtree(draw_border=(ax,points[:,x_col].max(),points[:,x_col].min(),points[:,y_col].max(),points[:,y_col].min(),points[:,z_col].max(),points[:,z_col].min()))
#ax.view_init(elev=10, azim=50)
plt.show()

image.png
分割結果
image.png

グルグルグラフを回転させてみると点が分けられていることがわかると思います。

探索

次に探索してみます。search関数の引数にx,y,zの範囲を指定すると返り値にその範囲内の点が返されます。

results = np.array(tree.search([[0,10],[8,10],[0,6]]))
print(results)
fig = plt.figure()
ax = fig.gca(projection='3d')
ax.scatter(results[:,x_col],results[:,y_col],results[:,z_col])
plt.xlim([points[:,x_col].min()-1,points[:,x_col].max()+1])
plt.ylim([points[:,y_col].min()-1,points[:,y_col].max()+1])
plt.show()

"""
results : [[ 3 10  4]
 [ 7  8  6]
 [10 10  4]]
"""

image.png

(グラフがわかりにくいですが)できました。
どちらかというと3次元に拡張するより、アルゴリズムの理解と境界のグラフ表示に時間がかかりました。

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0