3
4

More than 3 years have passed since last update.

ガウスカーネルをpythonでも爆速で計算する

Last updated at Posted at 2020-02-16

はじめに

ガウス過程等で計算が必要になるガウスカーネルをバク速で計算するコードを書きます.

コード

高速化するためにNumbaで記述しました.Numbaによる制約でコード自体は若干冗長になるのが残念ですが,高速化の効果は絶大です.Numbaで高速化するために2乗を掛け算で記述する,numpyの関数を自分で定義した関数内で使わない等の制約がありました.

from numba import jit, void, f8
import numpy as np
import time


@jit(void(f8[:, :], f8[:, :]))
def gauss_gram_mat(x, K):
  n_points = len(x)
  n_dim = len(x[0])
  b = 0
  sgm = 0.2

  for j in range(n_points):
    for i in range(n_points):
      for k in range(n_dim):
        b = (x[i][k] - x[j][k]) / sgm
        K[i][j] += b * b


def gauss_gram_mat_normal(x, K):
  n_points = len(x)
  n_dim = len(x[0])
  b = 0
  sgm = 0.2

  for j in range(n_points):
    for i in range(n_points):
      for k in range(n_dim):
        b = (x[i][k] - x[j][k]) / sgm
        K[i][j] += b * b


n_dim = 10
n_points = 2000
x = np.random.rand(n_points, n_dim)
K = np.zeros((n_points, n_points))

start = time.time()

gauss_gram_mat(x, K)
K = np.exp(- K / 2)

print("Namba: {}".format(time.time() - start))

start = time.time()

gauss_gram_mat_normal(x, K)
K = np.exp(- K / 2)

print("Normal: {}".format(time.time() - start))

検証

1パターンのみですが,上記の点の数及び次元数において通常のコードとNumbaによるコードで計算速度を比較しました.

見たところ500倍近く早くなりました.(内包表記等を活用すれば,Numba無でも早くなりますが,さすがにここまでは不可能.)

Numba: 0.11480522155761719
Normal: 50.70034885406494

補足

Numpyでも検証しました.

import numpy as np
import time

n_dim = 10
n_points = 2000
sgm = 0.2
x = np.random.rand(n_points, n_dim)

now = time.time()
K = np.exp(- 0.5 * (((x - x[:, None]) / sgm) ** 2).sum(axis=2))
print("Numpy: {}".format(time.time() - start))

Numbaの方がNumpyより高速という結果になりました.

Numpy: 0.3936312198638916
3
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
4