Inspired by numpyで明示的にループを書くと極端に遅くなる
上述の記事では明示的にforを書くと極端に遅くなるという話でした。
例えば
def matmul1(a, b):
lenI = a.shape[0]
lenJ = a.shape[1]
lenK = b.shape[1]
c = np.zeros((lenI, lenJ))
for i in range(lenI):
for j in range(lenJ):
for k in range(lenK):
c[i, j] += a[i, k] * b[k, j]
return c
のようなコードはnp.dot
と比べて遅くなります。
%timeit matmul1(a, b)
1 loops, best of 3: 12.9 s per loop
%timeit np.dot(a, b)
10 loops, best of 3: 20.7 ms per loop
手元のノートPCで計算しているので遅いです。またatlas/mklはリンクしてません。
ここでNumbaを使います。
import numba
@numba.jit # ここだけ追加
def matmul1_jit(a, b):
lenI = a.shape[0]
lenJ = a.shape[1]
lenK = b.shape[1]
c = np.zeros((lenI, lenJ))
for i in range(lenI):
for j in range(lenJ):
for k in range(lenK):
c[i, j] += a[i, k] * b[k, j]
return c
これはPythonのコードをLLVMを使ってJITコンパイルするため、非常に高速に実行できます。
初回の呼び出しはコンパイルする時間が含まれるので、それ以降の呼び出しで速度を計測すると:
%timeit matmul1_jit(a, b)
10 loops, best of 3: 24.4 ms per loop
このように1行追加するだけでnp.dot
より同程度(2割遅いくらい)になりました。