0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Numbaは再帰関数もコンパイル結果のキャッシュやAOT (事前) コンパイルできる

Last updated at Posted at 2023-12-22

はじめに

Numba で再帰関数を高速化する場合、コンパイル結果のキャッシュやAOT(事前)コンパイルができず、プログラム実行の度にコンパイル時間がとられるという制限があるらしい。

ところが挙動を試していたらキャッシュは特に工夫なくともできていたし、AOTは少しの工夫で成功してしまったので紹介する。

変更履歴

(2023-12-24) キャッシュ効果の測定コードと測定結果を追記

コンパイル結果のキャッシュ

方法

特に工夫せず cache=True するだけでできた。

  • @jit に型指定しない場合 (初回呼び出し時にコンパイル)
from numba import jit

# 再帰関数に @jit、cache=True
@jit(nopython=True, cache=True)
def rec_fib(n):
    if n == 0 or n == 1:
        return n
    return rec_fib(n - 1) + rec_fib(n - 2)

# 初回呼び出しでコンパイル or キャッシュ読込み
print(rec_fib(6))
  • @jit に型指定する場合 (関数定義時にコンパイル)
from numba import jit, boolean, int64, float64, typeof

# 再帰関数に @jit、型指定 + cache=True
@jit(int64(int64), nopython=True, cache=True)
def rec_fib(n):
    if n == 0 or n == 1:
        return n
    return rec_fib(n - 1) + rec_fib(n - 2)
# 関数定義時にコンパイル or キャッシュ読込み

print(rec_fib(6))

オーバーヘッド時間の測定結果はかなり長くなってしまったので後述。

なお、Numbaバージョンが最新の場合 (0.57.0) でも古めの場合 (0.48.0) でも同様の挙動だった。

AOT(事前)コンパイル

AOTは再帰関数に対して普通に試みるとエラーになる。

from numba.pycc import CC

cc = CC("my_module_rec")

@cc.export("rec_fib", "int64(int64)")
def rec_fib(n):
    if n == 0 or n == 1:
        return n
    return rec_fib(n - 1) + rec_fib(n - 2)


print("start-compile: rec_fib")
cc.compile()    # エラー!
print("finish-compile: rec_fib")

AOT成功コード例

以下のようにするとAOTに成功した。

from numba import jit
from numba.pycc import CC

cc = CC("my_module_rec")


@jit(nopython=True)    # (a)
def _rec_fib(n):
    if n == 0 or n == 1:
        return n
    return _rec_fib(n - 1) + _rec_fib(n - 2)

@cc.export("rec_fib_caller", "int64(int64)")    # (c)
def rec_fib_caller(n):    # (b)
    return _rec_fib(n)


print("start-compile: rec_fib_caller")
cc.compile()
print("finish-compile: rec_fib_caller")


print("\nコンパイル結果の動作確認")
import my_module_rec
print(f"{my_module_rec.rec_fib_caller(20) =}", "\n成功!!")
  • 工夫する点は以下の通り
    • 再帰関数には @jit するが cc.export しない (a)
    • 再帰関数を呼び出すだけの別の関数を用意する (b)
    • 呼び出し用関数を cc.export してコンパイルする (c)

AOTなのに @jit とは一体?、と思うだろうがこれが鍵だった。
これを外したときのエラーメッセージをから推測するに、cc.export した関数のコンパイル過程で内部で呼び出された関数があった時、JITコンパイル時と同様、呼び出し対象の関数が @jit されていればコンパイル結果に含まれるのではないかと思う。

こちらもNumbaバージョンが最新の場合 (0.57.0) でも古めの場合 (0.48.0) でも同様の挙動だった。

キャッシュの効果の検証測定

Numba のオーバーヘッドが複雑なので真面目に検証したらかなり長くなってしまいました。
ここからは完全に興味のある人向け。

型指定なしの場合

キャッシュによるオーバーヘッド削減効果の挙動を知るため、まず基本形として、非再帰のNumba化関数を3つ ( f0, f1, f2 ) 含む以下のコードを複数回実行した。

コンパイル時間が分かりやすくなるよう、関数定義では不要な分岐を追加してでたらめな処理を増やし、一方で関数呼び出しでは関数の中での計算が一瞬で終わるような値を与えている。

import sys, time

TIMES = []
TIMES.append(("start", time.perf_counter()))

import numpy as np
TIMES.append(("import numpy", time.perf_counter()))

import numba
from numba import jit
TIMES.append(("import numba", time.perf_counter()))

print("Python: ", sys.version)
print("Numpy: ", np.__version__)
print("Numba: ", numba.__version__)
print()
TIMES.append(("print_versions", time.perf_counter()))


@jit(nopython=True, cache=True)
def f0(n):
    if n >= 0:
        return n + 1
    # この後はコンパイル時間を増やすための無駄なコード
    arr = np.arange(0, 15, dtype=np.int64).reshape((3, 5))
    np.sort(arr[0])
    for i in range(arr.shape[0]): arr[i] += i
    return arr[0][0] + arr[1][1]

TIMES.append(("f0: def", time.perf_counter()))

r00 = f0(100)
TIMES.append(("f0: 1st call", time.perf_counter()))
r01 = f0(101)
TIMES.append(("f0: 2nd call", time.perf_counter()))


@jit(nopython=True, cache=True)
def f1(n):
    if n >= 0:
        return n - 1
    # この後はコンパイル時間を増やすための無駄なコード
    arr = np.arange(0, 15, dtype=np.int64).reshape((3, 5))
    np.sort(arr[1])
    for i in range(arr.shape[0]): arr[i] += i
    return arr[0][0] + arr[1][1]

TIMES.append(("f1: def", time.perf_counter()))

r10 = f1(10)
TIMES.append(("f1: 1st call", time.perf_counter()))
r11 = f1(11)
TIMES.append(("f1: 2nd call", time.perf_counter()))

@jit(nopython=True, cache=True)
def f2(n):
    if n >= 0:
        return n * 2
    # この後はコンパイル時間を増やすための無駄なコード
    arr = np.arange(0, 15, dtype=np.int64).reshape((3, 5))
    np.sort(arr[2])
    for i in range(arr.shape[0]): arr[i] += i
    return arr[0][0] + arr[1][1]

TIMES.append(("f2: def", time.perf_counter()))

r20 = f2(20)
TIMES.append(("f2: 1st call", time.perf_counter()))
r21 = f2(21)
TIMES.append(("f2: 2nd call", time.perf_counter()))



ELAPSED = [(TIMES[i][0], TIMES[i][1] - TIMES[i - 1][1]) for i in range(1, len(TIMES))]
for result in ELAPSED:
    print(result[0], "\t{:.6f}".format(result[1]))

print("\n", r00, r01, r10, r11, r20, r21)

測定結果 (型指定なし、非再帰)

  • 単位:sec (Python 3.11.4, Numpy 1.24.1, Numba 0.57.0)
プログラム実行1回目
(キャッシュ未作成)
プログラム実行2回目
(キャッシュ作成済み)
f0関数定義 0.024984 0.004102
f0呼び出し1回目 1.752147 0.141097
f0呼び出し2回目 0.000003 0.000003
f1関数定義 0.000335 0.000381
f1呼び出し1回目 0.286833 0.005060
f1呼び出し2回目 0.000003 0.000002
f2関数定義 0.000322 0.000357
f2呼び出し1回目 0.290225 0.004915
f2呼び出し2回目 0.000003 0.000001

この結果からオーバーヘッドの出方を整理すると以下のようになる。

  • キャッシュ未作成時、プログラム実行1回目の各関数の初回呼び出しで大きなオーバーヘッド発生 (コンパイル)
  • キャッシュ未作成時、一度のプログラム実行内で複数の関数を扱うと、1つ目の関数と2つ目以降の関数でオーバーヘッドの大きさが異なる
  • キャッシュ利用時は、1つ目の関数の初回呼び出しは100ミリ秒単位 (おそらくキャッシュファイルの読み込み)
  • キャッシュ利用時は、2つ目以降の関数の初回呼び出しが数ミリ秒単位
  • 一度のプログラム実行内で同一の関数を複数回呼び出すと、2回目以降の呼び出しはマイクロ秒単位

次に、再帰関数でもキャッシュが効くかを確かめるため、上記のコードから f2 を再帰関数 f2_rec に置き換えた以下のコードを測定した。

import sys, time

TIMES = []
TIMES.append(("start", time.perf_counter()))

import numpy as np
TIMES.append(("import numpy", time.perf_counter()))

import numba
from numba import jit
TIMES.append(("import numba", time.perf_counter()))

print("Python: ", sys.version)
print("Numpy: ", np.__version__)
print("Numba: ", numba.__version__)
print()
TIMES.append(("print_versions", time.perf_counter()))


@jit(nopython=True, cache=True)
def f0(n):
    if n >= 0:
        return n + 1
    # この後はコンパイル時間を増やすための無駄なコード
    arr = np.arange(0, 15, dtype=np.int64).reshape((3, 5))
    np.sort(arr[0])
    for i in range(arr.shape[0]): arr[i] += i
    return arr[0][0] + arr[1][1]

TIMES.append(("f0: def", time.perf_counter()))

r00 = f0(100)
TIMES.append(("f0: 1st call", time.perf_counter()))
r01 = f0(101)
TIMES.append(("f0: 2nd call", time.perf_counter()))


@jit(nopython=True, cache=True)
def f1(n):
    if n >= 0:
        return n - 1
    # この後はコンパイル時間を増やすための無駄なコード
    arr = np.arange(0, 15, dtype=np.int64).reshape((3, 5))
    np.sort(arr[1])
    for i in range(arr.shape[0]): arr[i] += i
    return arr[0][0] + arr[1][1]

TIMES.append(("f1: def", time.perf_counter()))

r10 = f1(10)
TIMES.append(("f1: 1st call", time.perf_counter()))
r11 = f1(11)
TIMES.append(("f1: 2nd call", time.perf_counter()))

@jit(nopython=True, cache=True)
def f2_rec(n):
    if n == 0 or n == 1:
        return n
    if n > 1:
        return f2_rec(n - 1) + f2_rec(n - 2)
    # この後はコンパイル時間を増やすための無駄なコード
    arr = np.arange(0, 15, dtype=np.int64).reshape((3, 5))
    np.sort(arr[2])
    for i in range(arr.shape[0]): arr[i] += i
    return arr[0][0] + arr[1][1]

TIMES.append(("f2_rec: def", time.perf_counter()))

r20 = f2_rec(20)
TIMES.append(("f2_rec: 1st call", time.perf_counter()))
r21 = f2_rec(21)
TIMES.append(("f2_rec: 2nd call", time.perf_counter()))



ELAPSED = [(TIMES[i][0], TIMES[i][1] - TIMES[i - 1][1]) for i in range(1, len(TIMES))]
for result in ELAPSED:
    print(result[0], "\t{:.6f}".format(result[1]))

print("\n", r00, r01, r10, r11, r20, r21)

測定結果 (型指定なし、再帰あり)

  • 単位:sec (Python 3.11.4, Numpy 1.24.1, Numba 0.57.0)
プログラム実行1回目
(キャッシュなし)
プログラム実行2回目
(キャッシュ作成済み)
f0関数定義 0.003905 0.003968
f0呼び出し1回目 1.511317 0.135252
f0呼び出し2回目 0.000003 0.000003
f1関数定義 0.000286 0.000358
f1呼び出し1回目 0.273137 0.004888
f1呼び出し2回目 0.000003 0.000002
f2_rec関数定義 0.000285 0.000318
f2_rec呼び出し1回目 0.334119 0.004873
f2_rec呼び出し2回目 0.000130 0.000130

キャッシュ利用時のf2_recの初回呼び出しオーバーヘッドはミリ秒単位と非再帰の場合と同等。
つまり、再帰関数でも関係なくキャッシュできている。

なお、Numbaバージョンが古い場合 (Numba 0.48.0) でも測定したが同様の挙動だった。

型指定ありの場合

型指定ありでも測定した。
まずは非再帰関数3つを含むコード。

型指定あり非再帰関数の測定コード (クリックで展開/折りたたみ)
import sys, time

TIMES = []
TIMES.append(("start", time.perf_counter()))

import numpy as np
TIMES.append(("import numpy", time.perf_counter()))

import numba
from numba import jit, boolean, int64, float64, typeof
TIMES.append(("import numba", time.perf_counter()))

print("Python: ", sys.version)
print("Numpy: ", np.__version__)
print("Numba: ", numba.__version__)
print()
TIMES.append(("print_versions", time.perf_counter()))


@jit(int64(int64), nopython=True, cache=True)
def f0(n):
    if n >= 0:
        return n + 1
    # この後はコンパイル時間を増やすための無駄なコード
    arr = np.arange(0, 15, dtype=np.int64).reshape((3, 5))
    np.sort(arr[0])
    for i in range(arr.shape[0]): arr[i] += i
    return arr[0][0] + arr[1][1]

TIMES.append(("f0: def", time.perf_counter()))

r00 = f0(100)
TIMES.append(("f0: 1st call", time.perf_counter()))
r01 = f0(101)
TIMES.append(("f0: 2nd call", time.perf_counter()))


@jit(int64(int64), nopython=True, cache=True)
def f1(n):
    if n >= 0:
        return n - 1
    # この後はコンパイル時間を増やすための無駄なコード
    arr = np.arange(0, 15, dtype=np.int64).reshape((3, 5))
    np.sort(arr[1])
    for i in range(arr.shape[0]): arr[i] += i
    return arr[0][0] + arr[1][1]

TIMES.append(("f1: def", time.perf_counter()))

r10 = f1(10)
TIMES.append(("f1: 1st call", time.perf_counter()))
r11 = f1(11)
TIMES.append(("f1: 2nd call", time.perf_counter()))

@jit(int64(int64), nopython=True, cache=True)
def f2(n):
    if n >= 0:
        return n * 2
    # この後はコンパイル時間を増やすための無駄なコード
    arr = np.arange(0, 15, dtype=np.int64).reshape((3, 5))
    np.sort(arr[2])
    for i in range(arr.shape[0]): arr[i] += i
    return arr[0][0] + arr[1][1]

TIMES.append(("f2: def", time.perf_counter()))

r20 = f2(20)
TIMES.append(("f2: 1st call", time.perf_counter()))
r21 = f2(21)
TIMES.append(("f2: 2nd call", time.perf_counter()))



ELAPSED = [(TIMES[i][0], TIMES[i][1] - TIMES[i - 1][1]) for i in range(1, len(TIMES))]
for result in ELAPSED:
    print(result[0], "\t{:.6f}".format(result[1]))

print("\n", r00, r01, r10, r11, r20, r21)

--- 折りたたみここまで ---

測定結果 (型指定あり、非再帰)

  • 単位:sec (Python 3.11.4, Numpy 1.24.1, Numba 0.57.0)
プログラム実行1回目
(キャッシュなし)
プログラム実行2回目
(キャッシュ作成済み)
f0関数定義 1.932430 0.149574
f0呼び出し1回目 0.000006 0.000004
f0呼び出し2回目 0.000001 0.000001
f1関数定義 0.288116 0.005622
f1呼び出し1回目 0.000005 0.000003
f1呼び出し2回目 0.000001 0.000001
f2関数定義 0.290771 0.005422
f2呼び出し1回目 0.000005 0.000003
f2呼び出し2回目 0.000001 0.000001

オーバーヘッドが発生するタイミングは型指定なしでは初回呼び出し時だったのが、型指定ありでは関数定義時に移動した。
オーバーヘッドの大きさは同等だった。(少し数値が異なるが繰り返し測定でのバラツキの方が大きい)

次に、再帰関数でもキャッシュが効くかを確かめるため、上記のコードから f2 を再帰関数 f2_rec に置き換えた。

型指定あり再帰関数の測定コード (クリックで展開/折りたたみ)
import sys, time

TIMES = []
TIMES.append(("start", time.perf_counter()))

import numpy as np
TIMES.append(("import numpy", time.perf_counter()))

import numba
from numba import jit, boolean, int64, float64, typeof
TIMES.append(("import numba", time.perf_counter()))

print("Python: ", sys.version)
print("Numpy: ", np.__version__)
print("Numba: ", numba.__version__)
print()
TIMES.append(("print_versions", time.perf_counter()))


@jit(int64(int64), nopython=True, cache=True)
def f0(n):
    if n >= 0:
        return n + 1
    # この後はコンパイル時間を増やすための無駄なコード
    arr = np.arange(0, 15, dtype=np.int64).reshape((3, 5))
    np.sort(arr[0])
    for i in range(arr.shape[0]): arr[i] += i
    return arr[0][0] + arr[1][1]

TIMES.append(("f0: def", time.perf_counter()))

r00 = f0(100)
TIMES.append(("f0: 1st call", time.perf_counter()))
r01 = f0(101)
TIMES.append(("f0: 2nd call", time.perf_counter()))


@jit(int64(int64), nopython=True, cache=True)
def f1(n):
    if n >= 0:
        return n - 1
    # この後はコンパイル時間を増やすための無駄なコード
    arr = np.arange(0, 15, dtype=np.int64).reshape((3, 5))
    np.sort(arr[1])
    for i in range(arr.shape[0]): arr[i] += i
    return arr[0][0] + arr[1][1]

TIMES.append(("f1: def", time.perf_counter()))

r10 = f1(10)
TIMES.append(("f1: 1st call", time.perf_counter()))
r11 = f1(11)
TIMES.append(("f1: 2nd call", time.perf_counter()))

@jit(int64(int64), nopython=True, cache=True)
def f2_rec(n):
    if n == 0 or n == 1:
        return n
    if n > 1:
        return f2_rec(n - 1) + f2_rec(n - 2)
    # この後はコンパイル時間を増やすための無駄なコード
    arr = np.arange(0, 15, dtype=np.int64).reshape((3, 5))
    np.sort(arr[2])
    for i in range(arr.shape[0]): arr[i] += i
    return arr[0][0] + arr[1][1]

TIMES.append(("f2_rec: def", time.perf_counter()))

r20 = f2_rec(20)
TIMES.append(("f2_rec: 1st call", time.perf_counter()))
r21 = f2_rec(21)
TIMES.append(("f2_rec: 2nd call", time.perf_counter()))



ELAPSED = [(TIMES[i][0], TIMES[i][1] - TIMES[i - 1][1]) for i in range(1, len(TIMES))]
for result in ELAPSED:
    print(result[0], "\t{:.6f}".format(result[1]))

print("\n", r00, r01, r10, r11, r20, r21)

--- 折りたたみここまで ---

測定結果 (型指定あり、再帰あり)

  • 単位:sec (Python 3.11.4, Numpy 1.24.1, Numba 0.57.0)
プログラム実行1回目
(キャッシュなし)
プログラム実行2回目
(キャッシュ作成済み)
f0関数定義 1.843977 0.138856
f0呼び出し1回目 0.000006 0.000004
f0呼び出し2回目 0.000001 0.000001
f1関数定義 0.287179 0.005209
f1呼び出し1回目 0.000005 0.000002
f1呼び出し2回目 0.000001 0.000001
f2_rec関数定義 0.350976 0.005049
f2_rec呼び出し1回目 0.000088 0.000081
f2_rec呼び出し2回目 0.000132 0.000133

キャッシュ利用時のf2_recの初回呼び出しオーバーヘッドはミリ秒単位と非再帰の場合と同等。
つまり、再帰関数でも関係なくキャッシュできている。

なお、Numbaバージョンが古い場合 (Numba 0.48.0) でも測定したが同様の挙動だった。

キャッシュ効果の測定まとめ

再帰関数でも非再帰関数でもキャッシュ利用により初回オーバーヘッドは同等に小さくなった。

まとめ

実は再帰関数もキャッシュやAOTコンパイルができた。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?