LoginSignup
1
3

More than 5 years have passed since last update.

メモ: numpyで全点間の距離を求める

Posted at

多次元座標が入った配列$x$と$y$について、$z_{ij}$に$x_i$と$y_j$の距離が入ったような配列$z$をどのように得るかと言ったお話。
numpyのブロードキャスト機能をうまく使うと簡潔にかけて50倍ぐらい早くなった


import timeit
x = np.random.random(10000).reshape(500,20)
y = np.random.random(10000).reshape(500,20)
def calic_dist(x,y):
    z = np.zeros((len(x),len(x)))
    for i in range(len(x)):
        for j in range(len(x)):
            z[i][j] = np.sum((x[i] -y[j] ** 2))
    z = np.sqrt(z)
    return z

def calic_dist_np(x,y):
    try:
        assert isinstance(x,np.ndarray)
    except AssertionError:
        x = np.array(x)
    try:
        assert isinstance(y,np.ndarray)
    except AssertionError:
        y = np.array(x)
    assert x.shape == y.shape
    z = x.reshape(x.shape[0],1,x.shape[1]) - y.reshape(1,x.shape[0],x.shape[1])
    zz = np.sum(z**2, axis=2)
    zz = np.sqrt(zz)
    return zz

%timeit z = calic_dist(x,y)
%timeit zz = calic_dist_np(x,y)
assert z == zz

実行結果はこんな感じ

1.34 s ± 22.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
23.3 ms ± 532 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

ブロードキャストを使っているのはz = x.reshape(x.shape[0],1,x.shape[1]) - y.reshape(1,x.shape[0],x.shape[1])
の部分。なんにせよ50倍ぐらいは速くなっていそうだ。
ただし普通にアルゴリズムとしては$O(n^2)$だと思われるので基本的にはこのような計算は避けたいところではあるのだろう。

1
3
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
3