Numbaのnp.sum,min,maxなどについて
他の記事でNumbaのnp.sumは遅いのかを調査する機会があった。
その時にその正体について気になった(というかわかった気がした)ので、覚書。
ついでにmin,maxについても考えた。
Numbaのnp.sum
コード
確認用コード(折りたたみ)
from numba import njit,prange
import numpy as np
import timeit
@njit
def my_sum0(arr):
return np.sum(arr)
@njit
def my_sum1(arr):
return_num = 0.
for i in arr.flat:
return_num += i
return return_num
@njit(parallel=True, error_model="numpy" ,fastmath=True)
def my_sum2(arr):
return_num = 0.
for i in prange(arr.size):
return_num += arr.flat[i]
return return_num
N=1000000
d1, d2, d3 = 4, 8, 3
d0 = int((N//(d1*d2*d3))**0.33)
arr = np.random.rand(d0*d1, d0*d2, d0*d3)
print(np.sum(arr),my_sum0(arr),my_sum1(arr),my_sum2(arr))
%timeit np.sum(arr)
%timeit my_sum0(arr)
%timeit my_sum1(arr)
%timeit my_sum2(arr)
結果
np.sum(arr) : 168 μs ± 1.07 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
my_sum0(arr) : 843 μs ± 2.4 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
my_sum1(arr) : 842 μs ± 1.95 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
my_sum2(arr) : 89.2 μs ± 5.59 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
Numbaのnp.sum
のまとめ
多分Numba内でのnp.sum
は
@njit
def my_sum1(arr):
return_num = 0.
for i in arr.flat:
return_num += i
return return_num
と同じだよねと思って測ってみた感じ全く同じタイムが出た。
何回かarrをいじってみたがほぼ同じタイムなので、まあ一緒なんだろうと思う。
てかnumpyが尋常じゃなく速い。すごいなあ。
my_sum2
では並列化とerror_model,fastmathをつけてみたところ、ぎりnumpyを超えた。
ただ、sumなんて並列化の中でも使うので実用的ではない。
きっとCPUコア数によるだろうし、Nで逆転もしそう。
Numbaのnp.min
コード
確認用コード(折りたたみ)
from numba import njit,prange,get_num_threads
import numpy as np
import timeit
@njit
def my_min0(arr):
return np.min(arr)
@njit
def my_min1(arr):
return_min = np.finfo(np.float64).max
for i in arr.flat:
if return_min>i:
return_min = i
return return_min
@njit(parallel=True, error_model="numpy" ,fastmath=True)
def my_min2(arr):
num_threads = get_num_threads()
return_min_arr = np.empty((num_threads),dtype="float64")
for i in prange(num_threads):
return_min = arr.flat[i]
for j in range(num_threads + i, arr.size, num_threads):
if return_min > arr.flat[j]:
return_min = arr.flat[j]
return_min_arr[i] = return_min
return np.min(return_min_arr)
N=1000000
d1, d2, d3 = 4, 8, 3
d0 = int((N//(d1*d2*d3))**0.33)
arr = np.random.rand(d0*d1, d0*d2, d0*d3)
print(np.min(arr),my_min0(arr),my_min1(arr),my_min2(arr))
%timeit np.min(arr)
%timeit my_min0(arr)
%timeit my_min1(arr)
%timeit my_min2(arr)
結果
np.min(arr) : 103 μs ± 5.26 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
my_min0(arr) : 1.15 ms ± 40.6 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
my_min1(arr) : 1.13 ms ± 9.65 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
my_min2(arr) : 268 μs ± 16.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
Numbaのnp.min
のまとめ
多分Numba内でのnp.min
は
@njit
def my_min1(arr):
return_min = np.finfo(np.float64).max
for i in arr.flat:
if return_min>i:
return_min = i
return return_min
と同じだよねと思って測ってみた感じ全く同じタイムが出た。
何回かarrをいじってみたがほぼ同じタイムなので、まあ一緒なんだろうと思う。
またもnumpyが尋常じゃなく速い。すごいなあ。
今度は並列化でも超えられなかった。というか並列化のコストがでかい。
Numbaのnp.max
流石にminの逆だろうと思ってやっていない。