#目的
データ分析で処理が重たいときはいつもCythonの力に頼っているが、簡単なループについてはnumbaのほうが保守性が高いように思えてきた。
どのくらいの高速化のメリットがあるかを確認し、指標値として覚えておく。
環境
Python==3.6.5
numpy==1.14.3
numba==0.38.0
全部anacondaに入っていたもの。
試したこと
まずはimport
from numba import jit
import numpy as np
通常のnp.arrayで計算した場合の関数
def sum_python(arr):
arr_size = len(arr)
result = 0
for i in range(arr_size):
result += arr[i]
return result
それをnumbaでjit化して、計算した場合の関数
(初稿時return resultではなく、returnとなっていたため、for文が無視されているとのご指摘をいただきました。以下修正版です。@yoshi123-xyz さん、ありがとうございました!!)
@jit
def sum_numba(arr):
arr_size = len(arr)
result = 0
for i in range(arr_size):
result += arr[i]
return result
以下、結果。
_arr = np.arange(10000000, dtype=np.int64)
%timeit sum_python(_arr)
1.47 s ± 51.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit sum_numba(_arr)
3.83 ms ± 111 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit np.sum(_arr)
5.03 ms ± 9.18 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
なんかおかしい。さすがにsum_numba()が速すぎるような気が。。。
(修正後)
というわけで、np.sum()に近い性能を得られることができました。これはかなり使えそう。