Numbaでベストパフォーマンスを出す
numbaは@njit
で簡単に速くなりますが、色々と詰めることでフォートランなどといった言語に劣らないぐらい速くなります(ケースによっては)。
たまたま速くする記事みたいなものを見つけて、深掘りしてみたので、メモしておきます。
なおタイトルに(1)とつけてますが、(2)があるかはわかりません。
先に結論だけ
新しく配列を作ろうと思うと途端に時間がかかる。
ので以下のようなことに気をつけて書く必要がある
-
arr = np.empty()
のような配列をつくるのはforループの浅いところで -
np.sum((v1-v2)**2)
は新しく配列を作ってしまう(v1[:] = (v1-v2)**2
もだめ) - 以下のように
for
を使用して、どうにかする
def my_sum(v1, v2):
var = 0.
for i in range(len(v1)):
var += (v1[i]-v2[i])**2
return var
とか、破壊的な計算でいいなら
def my_sum(v1, v2):
for i in range(len(v1)):
v1[i] = (v1[i]-v2[i])**2
return np.sum(v1)
-
for
でやるより遅くなりがちだが、破壊的な計算でいいならv1-=v2
もok
また、下では元記事にある他の高速化も試しているが、以下は特に効果がなかった。
-
a.shape[0]
等を使用しても遅くならない -
np.sum
は別に遅くならない
スペック
M1 macbookair, macOS=15.2(24C101), conda=23.7.4(あんま関係ないか)
Python=3.10.14, numba=0.60.0, numpy=1.26.4
元記事
以上の記事では、大雑把に「numbaがフォートランよりおせえ!どうにかならない?」といった内容で話しています。
なんの目的のコードなのかはよく見ていないので、割愛。
この記事の内容について確かめていきます
ベースのコード
メモで、すでに高速化用に少し手を加えているようなので、それを推測して戻しています。
また、元ページではrs = np.random.rand(d2, d2, d3)
が明らかにミスです。
正しくはrs = np.random.rand(d1, d2, d3)
で以下では直してます。
@njit
def pot(ra, rb):
f = np.sum((ra-rb)**2)
e = np.exp(f)
a = 1 - f
b = np.sqrt(f * e)
return a, b
@njit
def total(r, rs):
empty = np.array([1., 0.])
for i in range(rs.shape[0]):
for j in range(r.shape[0]):
temp = pot(r[j], rs[i,j])
empty[0] *= temp[0]
empty[1] += temp[1]
return empty
@njit
def runner(N):
d1 = 4
d2 = 8
d3 = 3
result = np.array([1., 0.])
r = np.zeros((d2, d3))
rs = np.zeros((d1, d2, d3))
for i in range(N):
r = np.random.rand(d2, d3)
rs = np.random.rand(d1, d2, d3)
temp = total(r, rs)
result[0] *= temp[0]
result[1] *= temp[1]
測定
確認用コード(折りたたみ)
t_start = time.time()
runner(1)
t_end = time.time()
run_time = t_end - t_start
print(f"Compiled in {round(run_time, 3)} seconds.")
%timeit runner(1000000)
実行してみると、
Compiled in 0.737 seconds.
1.97 s ± 73.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
変更点0(というか修正)
まず初めに、やりたいであろうことができていない点の修正。runner
の中で
def runner(N):
r = np.zeros((d2, d3))
rs = np.zeros((d1, d2, d3))
for i in range(N):
r = np.random.rand(d2, d3)
rs = np.random.rand(d2, d2, d3)
上記の部分では、pythonやっていればわかる人も多いと思うが、変数r
を毎度新たに定義し直している。元記事の中では以下のような関数を作ってランダム埋めをしています。
@njit
def random_inplace(a):
for i in range(a.size):
a.flat[i] = np.random.rand()
ってか別にランダムな配列を作る時間なんぞ欲しいわけではないので、外で定義しとけよと思い修正しました。
確認用コード(折りたたみ)
@njit
def runner(arr_r,arr_rs):
d1 = arr_rs.shape[1]
d2 = arr_rs.shape[2]
d3 = arr_rs.shape[3]
result = np.array([1., 0.])
for i in range(arr_r.shape[0]):
temp = total(arr_r[i], arr_rs[i])
result[0] *= temp[0]
result[1] *= temp[1]
N=1
d1, d2, d3 = 4, 8, 3
arr_r, arr_rs = np.random.rand(N, d2, d3), np.random.rand(N, d1, d2, d3)
t_start = time.time()
runner(arr_r, arr_rs)
t_end = time.time()
run_time = t_end - t_start
print(f"Compiled in {round(run_time, 3)} seconds.")
N=1000000
arr_r, arr_rs = np.random.rand(N, d2, d3), np.random.rand(N, d1, d2, d3)
%timeit runner(arr_r, arr_rs)
測定
Compiled in 0.667 seconds.
1.46 s ± 37.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
これをベースに考えていく。
変更点1
途中でr.shape[0]
といった部分が出てくる。これを使うと(元コードのコメント曰く)遅くなるらしい。
ので、total(r, rs, d1, d2):
といったようにそれぞれを引数として与える形に修正。
確認用コード(折りたたみ)
@njit
def pot(ra, rb):
f = np.sum((ra-rb)**2)
e = np.exp(f)
a = 1 - f
b = np.sqrt(f * e)
return a, b
@njit
- def total(r, rs): # ベース
+ def total(r, rs, d1, d2): # 変更点1
empty = np.array([1., 0.])
- for i in range(rs.shape[0]): # ベース
+ for i in range(d1): # 変更点1
- for j in range(r.shape[0]): # ベース
+ for j in range(ds): # 変更点1
temp = pot(r[j], rs[i,j])
empty[0] *= temp[0]
empty[1] += temp[1]
return empty
@njit
def runner(arr_r,arr_rs):
d1 = arr_rs.shape[1]
d2 = arr_rs.shape[2]
d3 = arr_rs.shape[3]
result = np.array([1., 0.])
for i in range(arr_r.shape[0]):
- temp = total(arr_r[i], arr_rs[i]) # ベース
+ temp = total(arr_r[i], arr_rs[i], d1, d2) # 変更点1
result[0] *= temp[0]
result[1] *= temp[1]
測定
Compiled in 0.694 seconds.
1.81 s ± 37.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
変わらず
まあ、ここ変わられると今までの全てのコードでshape使っているので困ってしまうし、まあええか
変更点2
total
の途中でempty = np.array([1., 0.])
が出てくる。別にnumpy.arrayである必要はないでしょと
ので、
確認用コード(折りたたみ)
@njit
def total(r, rs):
- empty = np.array([1., 0.])
+ empty0, empty1 = 1., 0.
for i in range(rs.shape[0]):
for j in range(r.shape[0]):
temp = pot(r[j], rs[i,j])
- empty[0] *= temp[0]
- empty[1] += temp[1]
- return empty
+ empty0 *= temp[0]
+ empty1 += temp[1]
+ return (empty0, empty1)
測定
Compiled in 0.646 seconds.
1.74 s ± 2.96 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
変わらず
変更点3
pot
の途中でnp.sum((ra-rb)**2)
が出てくる。これを使うと(コメント曰く)遅くなるらしい。
ので、
@njit
def my_sum(v1, v2):
v0 = 0.
for i in range(len(v1)):
v0 += (v1[i]-v2[i])**2
return v0
こんな関数を用意
確認用コード(折りたたみ)
@njit
def pot(ra, rb):
- f = np.sum((ra-rb)**2)
+ f = my_sum(ra,rb)
e = np.exp(f)
a = 1 - f
b = np.sqrt(f * e)
return a, b
測定
Compiled in 0.52 seconds.
232 ms ± 8.72 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
爆速、とはいえこんなこと毎回やってらんねえよ、つらいねえ
中間結果
変更 | Base | 変更1 | 変更2 | 変更3 |
---|---|---|---|---|
内容 | shapeを消す | ndarrayを消す |
np.sum(()**2) を消す |
|
Time | 1.46 s ± 37.8 ms | 1.81 s ± 37.7 ms | 1.54 s ± 56.1 ms | 232 ms ± 8.72 ms |
明らかに変更3のみ効果あり
ほんの少し変わっているかもなので、変更3に1,2を付け加える形で見ていく。
変更点1+3
測定
Compiled in 0.534 seconds.
230 ms ± 10.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
変更点2+3
測定
Compiled in 0.468 seconds.
166 ms ± 4.62 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
変更点1~3全部
測定
Compiled in 0.461 seconds.
166 ms ± 2.11 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
中間結果
変更 | 変更3 | 変更1+3 | 変更2+3 | 変更1+2+3 |
---|---|---|---|---|
Time | 232 ms ± 8.72 ms | 230 ms ± 10.9 ms | 166 ms ± 4.62 ms | 166 ms ± 2.11 ms |
変更点1はほぼ影響なさそう、変更2は微小な変化があるっぽい
np.sum((ra-rb)**2)
の何がダメなのか掘り下げ
以下、内容がよくないです。改めて後で別ページで書き直すと思います。
変更2のコードから追加でpot
のf = np.sum((ra-rb)**2)
を変更して検証
変更3-1
np.sum
のみ置き換え、f = my_sum1((ra-rb)**2)
@njit
def my_sum1(arr):
return_num = 0.
for i in arr.flat:
return_num += i
return return_num
Compiled in 0.73 seconds.
1.76 s ± 17.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
変更3-2
(ra-rb)**2
のみ置き換え、f = my_sum2(ra,rb)
@njit
def my_sum2(v1, v2):
arr = np.empty(len(v1),dtype="float64")
for i in range(len(v1)):
arr[i] = (v1[i]-v2[i])**2
return np.sum(arr)
Compiled in 0.751 seconds.
1.77 s ± 28.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
変更3-3
np.sum
,(ra-rb)**2
どちらも置き換え、
ただし(ra-rb)**2
の計算結果を入れる場所としてarr
を用意し、そこに1つづつ入れる。
f = my_sum3(ra,rb)
@njit
def my_sum3(v1, v2):
arr = np.empty(len(v1),dtype="float64")
for i in range(len(v1)):
arr[i] = (v1[i]-v2[i])**2
return my_sum1(arr)
Compiled in 0.736 seconds.
1.66 s ± 2.71 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
変更3-4
配列を新たに定義せずに、v1
を上書きしていく形で保存していく。
f = my_sum4(ra,rb)
@njit
def my_sum4(v1, v2):
for i in range(len(v1)):
v1[i] = (v1[i]-v2[i])**2
return np.sum(v1)
Compiled in 0.758 seconds.
230 ms ± 770 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
変更3-5
np.sum(v1)
が速度低下なしなことはわかったので、それ以外のところでどうにかならないか検討
f = my_sum5(ra,rb)
@njit
def my_sum5(v1, v2):
v1[:] = (v1-v2)**2
return np.sum(v1)
Compiled in 0.705 seconds.
1.86 s ± 20.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
変更3-6
3-5をもっと捻くれた書き方で
f = my_sum6(ra,rb)
@njit
def my_sum6(v1, v2):
v1-=v2
v1**=2
return np.sum(v1)
Compiled in 0.735 seconds.
604 ms ± 16.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
変更3-7
変更3の時に使ってかなり速くなったmy_sum
にarr = np.empty(~)
をつける
f = my_sum7(ra,rb)
@njit
def my_sum7(v1, v2):
arr = np.empty(len(v1),dtype="float64")
v0 = 0.
for i in range(len(v1)):
v0 += (v1[i]-v2[i])**2
return v0
Compiled in 0.658 seconds.
1.77 s ± 15.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
結論
上記の結果からわかる通り、新しく配列を作ろうと思うと途端に時間がかかっていると思われる。
また、(v1-v2)**2
のようなブロードキャスト前提の書き方は実は配列を新たに作ってしまうみたい
v1-=v2
のような書き方は配列を作らないが、forでやるより遅くなる。
つまり、できるだけブロードキャストを使わずにforで書く必要があるみたい、いや無理だが?