はじめに
ガウス過程等で計算が必要になるガウスカーネルをバク速で計算するコードを書きます.
コード
高速化するためにNumbaで記述しました.Numbaによる制約でコード自体は若干冗長になるのが残念ですが,高速化の効果は絶大です.Numbaで高速化するために2乗を掛け算で記述する,numpyの関数を自分で定義した関数内で使わない等の制約がありました.
from numba import jit, void, f8
import numpy as np
import time
@jit(void(f8[:, :], f8[:, :]))
def gauss_gram_mat(x, K):
n_points = len(x)
n_dim = len(x[0])
b = 0
sgm = 0.2
for j in range(n_points):
for i in range(n_points):
for k in range(n_dim):
b = (x[i][k] - x[j][k]) / sgm
K[i][j] += b * b
def gauss_gram_mat_normal(x, K):
n_points = len(x)
n_dim = len(x[0])
b = 0
sgm = 0.2
for j in range(n_points):
for i in range(n_points):
for k in range(n_dim):
b = (x[i][k] - x[j][k]) / sgm
K[i][j] += b * b
n_dim = 10
n_points = 2000
x = np.random.rand(n_points, n_dim)
K = np.zeros((n_points, n_points))
start = time.time()
gauss_gram_mat(x, K)
K = np.exp(- K / 2)
print("Namba: {}".format(time.time() - start))
start = time.time()
gauss_gram_mat_normal(x, K)
K = np.exp(- K / 2)
print("Normal: {}".format(time.time() - start))
検証
1パターンのみですが,上記の点の数及び次元数において通常のコードとNumbaによるコードで計算速度を比較しました.
見たところ500倍近く早くなりました.(内包表記等を活用すれば,Numba無でも早くなりますが,さすがにここまでは不可能.)
Numba: 0.11480522155761719
Normal: 50.70034885406494
補足
Numpyでも検証しました.
import numpy as np
import time
n_dim = 10
n_points = 2000
sgm = 0.2
x = np.random.rand(n_points, n_dim)
now = time.time()
K = np.exp(- 0.5 * (((x - x[:, None]) / sgm) ** 2).sum(axis=2))
print("Numpy: {}".format(time.time() - start))
Numbaの方がNumpyより高速という結果になりました.
Numpy: 0.3936312198638916