はじめに
Numba で再帰関数を高速化する場合、コンパイル結果のキャッシュやAOT(事前)コンパイルができず、プログラム実行の度にコンパイル時間がとられるという制限があるらしい。
ところが挙動を試していたらキャッシュは特に工夫なくともできていたし、AOTは少しの工夫で成功してしまったので紹介する。
変更履歴
(2023-12-24) キャッシュ効果の測定コードと測定結果を追記
コンパイル結果のキャッシュ
方法
特に工夫せず cache=True
するだけでできた。
-
@jit
に型指定しない場合 (初回呼び出し時にコンパイル)
from numba import jit
# 再帰関数に @jit、cache=True
@jit(nopython=True, cache=True)
def rec_fib(n):
if n == 0 or n == 1:
return n
return rec_fib(n - 1) + rec_fib(n - 2)
# 初回呼び出しでコンパイル or キャッシュ読込み
print(rec_fib(6))
-
@jit
に型指定する場合 (関数定義時にコンパイル)
from numba import jit, boolean, int64, float64, typeof
# 再帰関数に @jit、型指定 + cache=True
@jit(int64(int64), nopython=True, cache=True)
def rec_fib(n):
if n == 0 or n == 1:
return n
return rec_fib(n - 1) + rec_fib(n - 2)
# 関数定義時にコンパイル or キャッシュ読込み
print(rec_fib(6))
オーバーヘッド時間の測定結果はかなり長くなってしまったので後述。
なお、Numbaバージョンが最新の場合 (0.57.0) でも古めの場合 (0.48.0) でも同様の挙動だった。
AOT(事前)コンパイル
AOTは再帰関数に対して普通に試みるとエラーになる。
from numba.pycc import CC
cc = CC("my_module_rec")
@cc.export("rec_fib", "int64(int64)")
def rec_fib(n):
if n == 0 or n == 1:
return n
return rec_fib(n - 1) + rec_fib(n - 2)
print("start-compile: rec_fib")
cc.compile() # エラー!
print("finish-compile: rec_fib")
AOT成功コード例
以下のようにするとAOTに成功した。
from numba import jit
from numba.pycc import CC
cc = CC("my_module_rec")
@jit(nopython=True) # (a)
def _rec_fib(n):
if n == 0 or n == 1:
return n
return _rec_fib(n - 1) + _rec_fib(n - 2)
@cc.export("rec_fib_caller", "int64(int64)") # (c)
def rec_fib_caller(n): # (b)
return _rec_fib(n)
print("start-compile: rec_fib_caller")
cc.compile()
print("finish-compile: rec_fib_caller")
print("\nコンパイル結果の動作確認")
import my_module_rec
print(f"{my_module_rec.rec_fib_caller(20) =}", "\n成功!!")
- 工夫する点は以下の通り
- 再帰関数には
@jit
するがcc.export
しない (a) - 再帰関数を呼び出すだけの別の関数を用意する (b)
- 呼び出し用関数を
cc.export
してコンパイルする (c)
- 再帰関数には
AOTなのに @jit
とは一体?、と思うだろうがこれが鍵だった。
これを外したときのエラーメッセージをから推測するに、cc.export
した関数のコンパイル過程で内部で呼び出された関数があった時、JITコンパイル時と同様、呼び出し対象の関数が @jit
されていればコンパイル結果に含まれるのではないかと思う。
こちらもNumbaバージョンが最新の場合 (0.57.0) でも古めの場合 (0.48.0) でも同様の挙動だった。
キャッシュの効果の検証測定
Numba のオーバーヘッドが複雑なので真面目に検証したらかなり長くなってしまいました。
ここからは完全に興味のある人向け。
型指定なしの場合
キャッシュによるオーバーヘッド削減効果の挙動を知るため、まず基本形として、非再帰のNumba化関数を3つ ( f0
, f1
, f2
) 含む以下のコードを複数回実行した。
コンパイル時間が分かりやすくなるよう、関数定義では不要な分岐を追加してでたらめな処理を増やし、一方で関数呼び出しでは関数の中での計算が一瞬で終わるような値を与えている。
import sys, time
TIMES = []
TIMES.append(("start", time.perf_counter()))
import numpy as np
TIMES.append(("import numpy", time.perf_counter()))
import numba
from numba import jit
TIMES.append(("import numba", time.perf_counter()))
print("Python: ", sys.version)
print("Numpy: ", np.__version__)
print("Numba: ", numba.__version__)
print()
TIMES.append(("print_versions", time.perf_counter()))
@jit(nopython=True, cache=True)
def f0(n):
if n >= 0:
return n + 1
# この後はコンパイル時間を増やすための無駄なコード
arr = np.arange(0, 15, dtype=np.int64).reshape((3, 5))
np.sort(arr[0])
for i in range(arr.shape[0]): arr[i] += i
return arr[0][0] + arr[1][1]
TIMES.append(("f0: def", time.perf_counter()))
r00 = f0(100)
TIMES.append(("f0: 1st call", time.perf_counter()))
r01 = f0(101)
TIMES.append(("f0: 2nd call", time.perf_counter()))
@jit(nopython=True, cache=True)
def f1(n):
if n >= 0:
return n - 1
# この後はコンパイル時間を増やすための無駄なコード
arr = np.arange(0, 15, dtype=np.int64).reshape((3, 5))
np.sort(arr[1])
for i in range(arr.shape[0]): arr[i] += i
return arr[0][0] + arr[1][1]
TIMES.append(("f1: def", time.perf_counter()))
r10 = f1(10)
TIMES.append(("f1: 1st call", time.perf_counter()))
r11 = f1(11)
TIMES.append(("f1: 2nd call", time.perf_counter()))
@jit(nopython=True, cache=True)
def f2(n):
if n >= 0:
return n * 2
# この後はコンパイル時間を増やすための無駄なコード
arr = np.arange(0, 15, dtype=np.int64).reshape((3, 5))
np.sort(arr[2])
for i in range(arr.shape[0]): arr[i] += i
return arr[0][0] + arr[1][1]
TIMES.append(("f2: def", time.perf_counter()))
r20 = f2(20)
TIMES.append(("f2: 1st call", time.perf_counter()))
r21 = f2(21)
TIMES.append(("f2: 2nd call", time.perf_counter()))
ELAPSED = [(TIMES[i][0], TIMES[i][1] - TIMES[i - 1][1]) for i in range(1, len(TIMES))]
for result in ELAPSED:
print(result[0], "\t{:.6f}".format(result[1]))
print("\n", r00, r01, r10, r11, r20, r21)
測定結果 (型指定なし、非再帰)
- 単位:sec (Python 3.11.4, Numpy 1.24.1, Numba 0.57.0)
プログラム実行1回目 (キャッシュ未作成) |
プログラム実行2回目 (キャッシュ作成済み) |
|
---|---|---|
f0関数定義 | 0.024984 | 0.004102 |
f0呼び出し1回目 | 1.752147 | 0.141097 |
f0呼び出し2回目 | 0.000003 | 0.000003 |
f1関数定義 | 0.000335 | 0.000381 |
f1呼び出し1回目 | 0.286833 | 0.005060 |
f1呼び出し2回目 | 0.000003 | 0.000002 |
f2関数定義 | 0.000322 | 0.000357 |
f2呼び出し1回目 | 0.290225 | 0.004915 |
f2呼び出し2回目 | 0.000003 | 0.000001 |
この結果からオーバーヘッドの出方を整理すると以下のようになる。
- キャッシュ未作成時、プログラム実行1回目の各関数の初回呼び出しで大きなオーバーヘッド発生 (コンパイル)
- キャッシュ未作成時、一度のプログラム実行内で複数の関数を扱うと、1つ目の関数と2つ目以降の関数でオーバーヘッドの大きさが異なる
- キャッシュ利用時は、1つ目の関数の初回呼び出しは100ミリ秒単位 (おそらくキャッシュファイルの読み込み)
- キャッシュ利用時は、2つ目以降の関数の初回呼び出しが数ミリ秒単位
- 一度のプログラム実行内で同一の関数を複数回呼び出すと、2回目以降の呼び出しはマイクロ秒単位
次に、再帰関数でもキャッシュが効くかを確かめるため、上記のコードから f2
を再帰関数 f2_rec
に置き換えた以下のコードを測定した。
import sys, time
TIMES = []
TIMES.append(("start", time.perf_counter()))
import numpy as np
TIMES.append(("import numpy", time.perf_counter()))
import numba
from numba import jit
TIMES.append(("import numba", time.perf_counter()))
print("Python: ", sys.version)
print("Numpy: ", np.__version__)
print("Numba: ", numba.__version__)
print()
TIMES.append(("print_versions", time.perf_counter()))
@jit(nopython=True, cache=True)
def f0(n):
if n >= 0:
return n + 1
# この後はコンパイル時間を増やすための無駄なコード
arr = np.arange(0, 15, dtype=np.int64).reshape((3, 5))
np.sort(arr[0])
for i in range(arr.shape[0]): arr[i] += i
return arr[0][0] + arr[1][1]
TIMES.append(("f0: def", time.perf_counter()))
r00 = f0(100)
TIMES.append(("f0: 1st call", time.perf_counter()))
r01 = f0(101)
TIMES.append(("f0: 2nd call", time.perf_counter()))
@jit(nopython=True, cache=True)
def f1(n):
if n >= 0:
return n - 1
# この後はコンパイル時間を増やすための無駄なコード
arr = np.arange(0, 15, dtype=np.int64).reshape((3, 5))
np.sort(arr[1])
for i in range(arr.shape[0]): arr[i] += i
return arr[0][0] + arr[1][1]
TIMES.append(("f1: def", time.perf_counter()))
r10 = f1(10)
TIMES.append(("f1: 1st call", time.perf_counter()))
r11 = f1(11)
TIMES.append(("f1: 2nd call", time.perf_counter()))
@jit(nopython=True, cache=True)
def f2_rec(n):
if n == 0 or n == 1:
return n
if n > 1:
return f2_rec(n - 1) + f2_rec(n - 2)
# この後はコンパイル時間を増やすための無駄なコード
arr = np.arange(0, 15, dtype=np.int64).reshape((3, 5))
np.sort(arr[2])
for i in range(arr.shape[0]): arr[i] += i
return arr[0][0] + arr[1][1]
TIMES.append(("f2_rec: def", time.perf_counter()))
r20 = f2_rec(20)
TIMES.append(("f2_rec: 1st call", time.perf_counter()))
r21 = f2_rec(21)
TIMES.append(("f2_rec: 2nd call", time.perf_counter()))
ELAPSED = [(TIMES[i][0], TIMES[i][1] - TIMES[i - 1][1]) for i in range(1, len(TIMES))]
for result in ELAPSED:
print(result[0], "\t{:.6f}".format(result[1]))
print("\n", r00, r01, r10, r11, r20, r21)
測定結果 (型指定なし、再帰あり)
- 単位:sec (Python 3.11.4, Numpy 1.24.1, Numba 0.57.0)
プログラム実行1回目 (キャッシュなし) |
プログラム実行2回目 (キャッシュ作成済み) |
|
---|---|---|
f0関数定義 | 0.003905 | 0.003968 |
f0呼び出し1回目 | 1.511317 | 0.135252 |
f0呼び出し2回目 | 0.000003 | 0.000003 |
f1関数定義 | 0.000286 | 0.000358 |
f1呼び出し1回目 | 0.273137 | 0.004888 |
f1呼び出し2回目 | 0.000003 | 0.000002 |
f2_rec関数定義 | 0.000285 | 0.000318 |
f2_rec呼び出し1回目 | 0.334119 | 0.004873 |
f2_rec呼び出し2回目 | 0.000130 | 0.000130 |
キャッシュ利用時のf2_recの初回呼び出しオーバーヘッドはミリ秒単位と非再帰の場合と同等。
つまり、再帰関数でも関係なくキャッシュできている。
なお、Numbaバージョンが古い場合 (Numba 0.48.0) でも測定したが同様の挙動だった。
型指定ありの場合
型指定ありでも測定した。
まずは非再帰関数3つを含むコード。
型指定あり非再帰関数の測定コード (クリックで展開/折りたたみ)
import sys, time
TIMES = []
TIMES.append(("start", time.perf_counter()))
import numpy as np
TIMES.append(("import numpy", time.perf_counter()))
import numba
from numba import jit, boolean, int64, float64, typeof
TIMES.append(("import numba", time.perf_counter()))
print("Python: ", sys.version)
print("Numpy: ", np.__version__)
print("Numba: ", numba.__version__)
print()
TIMES.append(("print_versions", time.perf_counter()))
@jit(int64(int64), nopython=True, cache=True)
def f0(n):
if n >= 0:
return n + 1
# この後はコンパイル時間を増やすための無駄なコード
arr = np.arange(0, 15, dtype=np.int64).reshape((3, 5))
np.sort(arr[0])
for i in range(arr.shape[0]): arr[i] += i
return arr[0][0] + arr[1][1]
TIMES.append(("f0: def", time.perf_counter()))
r00 = f0(100)
TIMES.append(("f0: 1st call", time.perf_counter()))
r01 = f0(101)
TIMES.append(("f0: 2nd call", time.perf_counter()))
@jit(int64(int64), nopython=True, cache=True)
def f1(n):
if n >= 0:
return n - 1
# この後はコンパイル時間を増やすための無駄なコード
arr = np.arange(0, 15, dtype=np.int64).reshape((3, 5))
np.sort(arr[1])
for i in range(arr.shape[0]): arr[i] += i
return arr[0][0] + arr[1][1]
TIMES.append(("f1: def", time.perf_counter()))
r10 = f1(10)
TIMES.append(("f1: 1st call", time.perf_counter()))
r11 = f1(11)
TIMES.append(("f1: 2nd call", time.perf_counter()))
@jit(int64(int64), nopython=True, cache=True)
def f2(n):
if n >= 0:
return n * 2
# この後はコンパイル時間を増やすための無駄なコード
arr = np.arange(0, 15, dtype=np.int64).reshape((3, 5))
np.sort(arr[2])
for i in range(arr.shape[0]): arr[i] += i
return arr[0][0] + arr[1][1]
TIMES.append(("f2: def", time.perf_counter()))
r20 = f2(20)
TIMES.append(("f2: 1st call", time.perf_counter()))
r21 = f2(21)
TIMES.append(("f2: 2nd call", time.perf_counter()))
ELAPSED = [(TIMES[i][0], TIMES[i][1] - TIMES[i - 1][1]) for i in range(1, len(TIMES))]
for result in ELAPSED:
print(result[0], "\t{:.6f}".format(result[1]))
print("\n", r00, r01, r10, r11, r20, r21)
--- 折りたたみここまで ---
測定結果 (型指定あり、非再帰)
- 単位:sec (Python 3.11.4, Numpy 1.24.1, Numba 0.57.0)
プログラム実行1回目 (キャッシュなし) |
プログラム実行2回目 (キャッシュ作成済み) |
|
---|---|---|
f0関数定義 | 1.932430 | 0.149574 |
f0呼び出し1回目 | 0.000006 | 0.000004 |
f0呼び出し2回目 | 0.000001 | 0.000001 |
f1関数定義 | 0.288116 | 0.005622 |
f1呼び出し1回目 | 0.000005 | 0.000003 |
f1呼び出し2回目 | 0.000001 | 0.000001 |
f2関数定義 | 0.290771 | 0.005422 |
f2呼び出し1回目 | 0.000005 | 0.000003 |
f2呼び出し2回目 | 0.000001 | 0.000001 |
オーバーヘッドが発生するタイミングは型指定なしでは初回呼び出し時だったのが、型指定ありでは関数定義時に移動した。
オーバーヘッドの大きさは同等だった。(少し数値が異なるが繰り返し測定でのバラツキの方が大きい)
次に、再帰関数でもキャッシュが効くかを確かめるため、上記のコードから f2
を再帰関数 f2_rec
に置き換えた。
型指定あり再帰関数の測定コード (クリックで展開/折りたたみ)
import sys, time
TIMES = []
TIMES.append(("start", time.perf_counter()))
import numpy as np
TIMES.append(("import numpy", time.perf_counter()))
import numba
from numba import jit, boolean, int64, float64, typeof
TIMES.append(("import numba", time.perf_counter()))
print("Python: ", sys.version)
print("Numpy: ", np.__version__)
print("Numba: ", numba.__version__)
print()
TIMES.append(("print_versions", time.perf_counter()))
@jit(int64(int64), nopython=True, cache=True)
def f0(n):
if n >= 0:
return n + 1
# この後はコンパイル時間を増やすための無駄なコード
arr = np.arange(0, 15, dtype=np.int64).reshape((3, 5))
np.sort(arr[0])
for i in range(arr.shape[0]): arr[i] += i
return arr[0][0] + arr[1][1]
TIMES.append(("f0: def", time.perf_counter()))
r00 = f0(100)
TIMES.append(("f0: 1st call", time.perf_counter()))
r01 = f0(101)
TIMES.append(("f0: 2nd call", time.perf_counter()))
@jit(int64(int64), nopython=True, cache=True)
def f1(n):
if n >= 0:
return n - 1
# この後はコンパイル時間を増やすための無駄なコード
arr = np.arange(0, 15, dtype=np.int64).reshape((3, 5))
np.sort(arr[1])
for i in range(arr.shape[0]): arr[i] += i
return arr[0][0] + arr[1][1]
TIMES.append(("f1: def", time.perf_counter()))
r10 = f1(10)
TIMES.append(("f1: 1st call", time.perf_counter()))
r11 = f1(11)
TIMES.append(("f1: 2nd call", time.perf_counter()))
@jit(int64(int64), nopython=True, cache=True)
def f2_rec(n):
if n == 0 or n == 1:
return n
if n > 1:
return f2_rec(n - 1) + f2_rec(n - 2)
# この後はコンパイル時間を増やすための無駄なコード
arr = np.arange(0, 15, dtype=np.int64).reshape((3, 5))
np.sort(arr[2])
for i in range(arr.shape[0]): arr[i] += i
return arr[0][0] + arr[1][1]
TIMES.append(("f2_rec: def", time.perf_counter()))
r20 = f2_rec(20)
TIMES.append(("f2_rec: 1st call", time.perf_counter()))
r21 = f2_rec(21)
TIMES.append(("f2_rec: 2nd call", time.perf_counter()))
ELAPSED = [(TIMES[i][0], TIMES[i][1] - TIMES[i - 1][1]) for i in range(1, len(TIMES))]
for result in ELAPSED:
print(result[0], "\t{:.6f}".format(result[1]))
print("\n", r00, r01, r10, r11, r20, r21)
--- 折りたたみここまで ---
測定結果 (型指定あり、再帰あり)
- 単位:sec (Python 3.11.4, Numpy 1.24.1, Numba 0.57.0)
プログラム実行1回目 (キャッシュなし) |
プログラム実行2回目 (キャッシュ作成済み) |
|
---|---|---|
f0関数定義 | 1.843977 | 0.138856 |
f0呼び出し1回目 | 0.000006 | 0.000004 |
f0呼び出し2回目 | 0.000001 | 0.000001 |
f1関数定義 | 0.287179 | 0.005209 |
f1呼び出し1回目 | 0.000005 | 0.000002 |
f1呼び出し2回目 | 0.000001 | 0.000001 |
f2_rec関数定義 | 0.350976 | 0.005049 |
f2_rec呼び出し1回目 | 0.000088 | 0.000081 |
f2_rec呼び出し2回目 | 0.000132 | 0.000133 |
キャッシュ利用時のf2_recの初回呼び出しオーバーヘッドはミリ秒単位と非再帰の場合と同等。
つまり、再帰関数でも関係なくキャッシュできている。
なお、Numbaバージョンが古い場合 (Numba 0.48.0) でも測定したが同様の挙動だった。
キャッシュ効果の測定まとめ
再帰関数でも非再帰関数でもキャッシュ利用により初回オーバーヘッドは同等に小さくなった。
まとめ
実は再帰関数もキャッシュやAOTコンパイルができた。