
Python: Fast computation of vector differences

Posted at 2023-06-07

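Problem setting

  • Group A: N_A three-dimensional vectors (x, y, z)
  • Group B: N_B three-dimensional vectors (u, v, w)

Given these two groups, we want the minimum distance between vectors over all N_A * N_B pairs formed by one element of A and one element of B.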

Mathematical formulation

$$\min_{i,j} \left\| (x,y,z)_i - (u,v,w)_j \right\|$$

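
As a reference point, the minimum over all pairs can be evaluated directly with a double loop. The following is only a minimal sketch: the function name `min_distance_naive` and the small input are made up for illustration and do not come from the article. It is slow for large N_A and N_B, but useful for checking a vectorized version.

```python
import math

def min_distance_naive(points_a, points_b):
    """Return (d_min, i, j) by checking every pair (i, j) explicitly."""
    best = (math.inf, -1, -1)
    for i, (x, y, z) in enumerate(points_a):
        for j, (u, v, w) in enumerate(points_b):
            # Euclidean distance between (x, y, z)_i and (u, v, w)_j
            d = math.sqrt((x - u) ** 2 + (y - v) ** 2 + (z - w) ** 2)
            if d < best[0]:
                best = (d, i, j)
    return best

# Hypothetical input: 4 vectors in group A, 2 vectors in group B
points_a = [(1, 0, 0), (2, 0, 0), (3, 0, 0), (0, 0, 0)]
points_b = [(4, 0, 0), (5, 0, 0)]
print(min_distance_naive(points_a, points_b))  # (1.0, 2, 0)
```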

Implementation in Python

A naive implementation could loop over all N_A x N_B pairs, but numpy.tile gives a concise and fast vectorized version.

import numpy

x = data_x # assumed to be a 1-D numpy.array()
y = data_y
z = data_z
        
u = data_u
v = data_v
w = data_w
       
N_A = len(data_x)
N_B = len(data_u)

matrix_A = numpy.vstack([x, y, z])
matrix_B = numpy.vstack([u, v, w])
        
matrix_A_T__xNB = numpy.tile(matrix_A.T, N_B).reshape(N_A*N_B,3)
matrix_B__xNA_T = numpy.tile(matrix_B, N_A).T
diff_AB_2 = (matrix_A_T__xNB - matrix_B__xNA_T)**2
dmin_AB = numpy.sqrt(numpy.min(numpy.sum(diff_AB_2, axis=1)))
argmin_diff_AB_2 = numpy.argmin(numpy.sum(diff_AB_2, axis=1)) # index

print(dmin_AB)
print(argmin_diff_AB_2)

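# Row k of the tiled arrays pairs A_i with B_j, where k = i*N_B + j
print(int(argmin_diff_AB_2 / N_B)) # index i at which the minimum is attained (0<=i<N_A)
print(   (argmin_diff_AB_2 % N_B)) # index j at which the minimum is attained (0<=j<N_B)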
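
The row layout of the two tiled arrays can be checked directly. Row k holds A_i and B_j with k = i*N_B + j, so i = k // N_B and j = k % N_B. The sketch below uses hypothetical sizes and random data, and the helper name `decode` is an assumption introduced only for this illustration.

```python
import numpy

rng = numpy.random.default_rng(0)
N_A, N_B = 4, 3                      # hypothetical sizes
matrix_A = rng.random((3, N_A))      # same (3, N) layout as numpy.vstack([x, y, z])
matrix_B = rng.random((3, N_B))

tiled_A = numpy.tile(matrix_A.T, N_B).reshape(N_A * N_B, 3)
tiled_B = numpy.tile(matrix_B, N_A).T

def decode(k, n_b):
    """Map a flat row index k back to (i, j), assuming k = i * n_b + j."""
    return divmod(int(k), n_b)

# Every row k must hold vector i of A and vector j of B
for k in range(N_A * N_B):
    i, j = decode(k, N_B)
    assert numpy.array_equal(tiled_A[k], matrix_A[:, i])
    assert numpy.array_equal(tiled_B[k], matrix_B[:, j])

print(decode(7, N_B))  # (2, 1)
```

Using N_A as the modulus would only give the right j by coincidence, for example when the flat index happens to be a multiple of both N_A and N_B.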

example

import numpy

def doit():
    # A
    x = [1, 2, 3, 0] # Jr_i
    y = [0, 0, 0, 0] # Jz_i
    z = [0, 0, 0, 0] # Jphi_i

    # B
    u = numpy.array([4,5]) # _Jr_c
    v = numpy.array([0,0]) # _Jz_c
    w = numpy.array([0,0]) # _Jphi_c

    N_A = len(x) #4   # N_MC
    N_B = len(u) #2   # _k 
    
    matrix_A = numpy.vstack([x, y, z])
    matrix_B = numpy.vstack([u, v, w])

    matrix_A_T__xNB = numpy.tile(matrix_A.T, N_B).reshape(N_A*N_B,3)
    matrix_B__xNA_T = numpy.tile(matrix_B, N_A).T

    print()
    print('# matrix_A')
    print(matrix_A)
    print()
    print('# matrix_B')
    print(matrix_B)
    
    print()
    print('# matrix_A_T__xNB')
    print(matrix_A_T__xNB)
    
    print()
    print('# matrix_B__xNA_T')
    print(matrix_B__xNA_T)
    
    
    diff_AB_2 = (matrix_A_T__xNB - matrix_B__xNA_T)**2
    dmin_AB = numpy.sqrt(numpy.min(numpy.sum(diff_AB_2, axis=1)))
    argmin_diff_AB_2 = numpy.argmin(numpy.sum(diff_AB_2, axis=1))

    print()
    print('# diff_AB_2')
    print(diff_AB_2)
    
    
    print()
    print('# argmin_diff_AB_2')
    print(argmin_diff_AB_2)
    
    print()
    print('# optimal value of i (0<=i<N_A)')
    print('# int(argmin_diff_AB_2/N_B)')
    print(int(argmin_diff_AB_2/N_B))
    
    print()
    print('# optimal value of j (0<=j<N_B)')
    print('# argmin_diff_AB_2 % N_B')
    print(argmin_diff_AB_2 % N_B)
doit()

### Output ###

# matrix_A
[[1 2 3 0]
 [0 0 0 0]
 [0 0 0 0]]

# matrix_B
[[4 5]
 [0 0]
 [0 0]]

# matrix_A_T__xNB
[[1 0 0]
 [1 0 0]
 [2 0 0]
 [2 0 0]
 [3 0 0]
 [3 0 0]
 [0 0 0]
 [0 0 0]]

# matrix_B__xNA_T
[[4 0 0]
 [5 0 0]
 [4 0 0]
 [5 0 0]
 [4 0 0]
 [5 0 0]
 [4 0 0]
 [5 0 0]]

# diff_AB_2
[[ 9  0  0]
 [16  0  0]
 [ 4  0  0]
 [ 9  0  0]
 [ 1  0  0]
 [ 4  0  0]
 [16  0  0]
 [25  0  0]]

# argmin_diff_AB_2
4

# optimal value of i (0<=i<N_A)
# int(argmin_diff_AB_2/N_B)
2

# optimal value of j (0<=j<N_B)
# argmin_diff_AB_2 % N_B
0
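
For comparison, the same minimum can probably also be obtained with NumPy broadcasting instead of numpy.tile. This is a different technique from the one used in this article, shown only as a sketch: the function name `min_distance_broadcast` and the (N, 3) input layout are assumptions made for the illustration. It still allocates an intermediate array of shape (N_A, N_B, 3), so it is not a memory saving, but numpy.unravel_index returns i and j without manual index arithmetic.

```python
import numpy

def min_distance_broadcast(points_a, points_b):
    """Return (d_min, i, j) for inputs of shape (N_A, 3) and (N_B, 3)."""
    a = numpy.asarray(points_a, dtype=float)
    b = numpy.asarray(points_b, dtype=float)
    # (N_A, 1, 3) - (1, N_B, 3) broadcasts to (N_A, N_B, 3)
    diff = a[:, None, :] - b[None, :, :]
    dist2 = numpy.sum(diff ** 2, axis=2)  # squared distances, shape (N_A, N_B)
    i, j = numpy.unravel_index(numpy.argmin(dist2), dist2.shape)
    return float(numpy.sqrt(dist2[i, j])), int(i), int(j)

# Same vectors as in the example above, stored one vector per row
points_a = numpy.array([[1, 0, 0], [2, 0, 0], [3, 0, 0], [0, 0, 0]])
points_b = numpy.array([[4, 0, 0], [5, 0, 0]])
print(min_distance_broadcast(points_a, points_b))  # (1.0, 2, 0)
```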
