はじめに
Pythonを高速化するNumbaにはクラスをNumba化するjitclass
があるが、jitclassはキャッシュやAOT (事前) コンパイルができないと思っていた。
そんな中、jitclassもキャッシュできるという情報を発見。
https://github.com/numba/numba/issues/4830#issuecomment-862424725
それどころかAOTもできるという情報も見つけた。
https://github.com/riantkb/typical90_python の L: 012、Q: 017
なにやら、Numba化された関数から@jitclass
したクラスを呼び出す構造にしておいて、元の関数をキャッシュもしくはAOTするとクラスの内容も保存されるとか。
本記事はjitclassのキャッシュ/AOTが可能か検証しただけだが、日本語どころか公式 (jitclass、AOT1、AOT2) にも情報が乏しいので共有しておく。
目次
- 実験用コード
- キャッシュとAOTの最低オーバーヘッド
- 追試: クラスのコンパイルを重くしてみる
- 注意: キャッシュを利用したい場合はjitclassの呼出をキャッシュ済みNumba化関数から行うこと
- おわり
実験用コード
サンプルとしてはリングバッファのようなものを作った。
手元 (Numba 0.53) とAtCoderのコードテスト (Numba 0.48) どちらでも動くようにしてある。本記事の測定結果はコードテストでのもの。
クラスのコンパイル時間と呼出元関数のそれとを分離する方法が思いつかなかったので、呼出元関数は極力短くした。コンパイル時間が分かりやすいようにJITでも型指定をつけ、計算は最小限の内容とした。
AtCoderにおけるNumba化関数のキャッシュとAOTについてはyniji氏の記事 「AtCoderで Python を高速化する Numpy + Numba を使う」を参照。
測定用コード-JIT用
N = 10
from time import perf_counter
measured = []
measured.append(('begin', perf_counter(), 0))
import numpy as np
from numba import jit, int64, typeof
try:
from numba.experimental import jitclass
except ImportError:
from numba import jitclass
measured.append(('import_libraries', perf_counter(), 1))
spec = [('_top', int64),
('_bottom', int64),
('_arraysize', int64),
('_dtype', typeof(np.int64)),
('_array', int64[:]),]
@jitclass(spec)
class Buffer():
def __init__(self):
self._top = 0
self._bottom = 0
self._arraysize = 4
self._dtype = np.int64
self._array = np.zeros(self._arraysize, dtype=self._dtype)
@property
def nholdings(self):
return self._top - self._bottom
def pop(self):
if self.nholdings <= 0:
raise Exception
ret = self._array[self._ind(self._top - 1)]
self._top -= 1
return ret
def popleft(self):
if self.nholdings <= 0:
raise Exception
ret = self._array[self._ind(self._bottom)]
self._bottom += 1
return ret
def append(self, value):
if self.nholdings >= self._arraysize:
self._extend()
self._array[self._ind(self._top)] = value
self._top += 1
return None
def _extend(self):
oldsize = self._arraysize
new = np.zeros(oldsize * 2, dtype=self._dtype)
if self._ind(self._bottom) < self._ind(self._top):
new[0: oldsize] = self._array
else:
new[self._ind(self._bottom): oldsize] = (
self._array[self._ind(self._bottom): oldsize])
new[oldsize: oldsize + self._ind(self._top)] = (
self._array[0: 0 + self._ind(self._top)])
self._array = new
self._arraysize *= 2
return None
def _ind(self, pos):
return pos % self._arraysize
measured.append(('define_jitclass', perf_counter(), 2))
'''切換え用
@jit(nopython=True) # キャッシュなし(型指定なし)
@jit(int64(int64), nopython=True, cache=False) # キャッシュなし(型指定あり)
@jit(int64(int64), nopython=True, cache=True) # キャッシュあり(型指定あり)
'''
@jit(int64(int64), nopython=True, cache=True) # キャッシュあり(型指定あり)
def func(x):
a = Buffer()
a.append(x * 10)
v = a.popleft()
return v
measured.append(('define_func', perf_counter(), 3))
ans = func(N)
measured.append(('call_func_1st', perf_counter(), ans))
ans = func(N)
measured.append(('call_func_2nd', perf_counter(), ans))
ans = func(N)
measured.append(('call_func_3rd', perf_counter(), ans))
ans = func(N)
measured.append(('call_func_4th', perf_counter(), ans))
for i in range(len(measured)):
if i == 0:
continue
elapsedtime = str(round(measured[i][1] - measured[i - 1][1], 3))
print(elapsedtime, measured[i][0], measured[i][2], sep='\t')
i = 4
elapsedtime = str(round(measured[i][1] - measured[0][1], 3))
print(elapsedtime, 'from-begin to-' + measured[i][0], sep='\t')
elapsedtime = str(round(measured[-1][1] - measured[0][1], 3))
print(elapsedtime, 'total', sep='\t')
測定用コード-AOT用
N = 10
from time import perf_counter
measured = []
measured.append(('begin', perf_counter(), 0))
pass # import numpy as np # numpy のインポートもスキップ
measured.append(('import_libraries', perf_counter(), 1))
measured.append(('define_jitclass', perf_counter(), 2))
try:
from my_module import func
except ImportError:
print('---compiling...---')
import numpy as np
import numba
from numba import int64, typeof
try:
from numba.experimental import jitclass
except ImportError:
from numba import jitclass
from numba.pycc import CC
spec = [('_top', int64),
('_bottom', int64),
('_arraysize', int64),
('_dtype', typeof(np.int64)),
('_array', int64[:]),]
@jitclass(spec)
class Buffer():
def __init__(self):
self._top = 0
self._bottom = 0
self._arraysize = 4
self._dtype = np.int64
self._array = np.zeros(self._arraysize, dtype=self._dtype)
@property
def nholdings(self):
return self._top - self._bottom
def pop(self):
if self.nholdings <= 0:
raise Exception
ret = self._array[self._ind(self._top - 1)]
self._top -= 1
return ret
def popleft(self):
if self.nholdings <= 0:
raise Exception
ret = self._array[self._ind(self._bottom)]
self._bottom += 1
return ret
def append(self, value):
if self.nholdings >= self._arraysize:
self._extend()
self._array[self._ind(self._top)] = value
self._top += 1
return None
def _extend(self):
oldsize = self._arraysize
new = np.zeros(oldsize * 2, dtype=self._dtype)
if self._ind(self._bottom) < self._ind(self._top):
new[0: oldsize] = self._array
else:
new[self._ind(self._bottom): oldsize] = (
self._array[self._ind(self._bottom): oldsize])
new[oldsize: oldsize + self._ind(self._top)] = (
self._array[0: 0 + self._ind(self._top)])
self._array = new
self._arraysize *= 2
return None
def _ind(self, pos):
return pos % self._arraysize
def func(x): # AOTでは@jitを使わない
a = Buffer()
a.append(x * 10)
v = a.popleft()
return v
cc = CC('my_module')
cc.export('func', 'int64(int64)')(func)
cc.compile()
print('---compile end---')
exit(0)
else:
pass
measured.append(('define_func', perf_counter(), 3))
ans = func(N)
measured.append(('call_func_1st', perf_counter(), ans))
ans = func(N)
measured.append(('call_func_2nd', perf_counter(), ans))
ans = func(N)
measured.append(('call_func_3rd', perf_counter(), ans))
ans = func(N)
measured.append(('call_func_4th', perf_counter(), ans))
for i in range(len(measured)):
if i == 0:
continue
elapsedtime = str(round(measured[i][1] - measured[i - 1][1], 3))
print(elapsedtime, measured[i][0], measured[i][2], sep='\t')
i = 4
elapsedtime = str(round(measured[i][1] - measured[0][1], 3))
print(elapsedtime, 'from-begin to-' + measured[i][0], sep='\t')
elapsedtime = str(round(measured[-1][1] - measured[0][1], 3))
print(elapsedtime, 'total', sep='\t')
測定結果
(単位: sec)
コード | ライブラリ インポート |
クラス定義 | 関数定義 (AOTは読込) |
初回呼出 | 2回目呼出 |
---|---|---|---|---|---|
JITキャッシュなし (型指定なし) |
0.325 | 0.001 | 0.000 | 1.187 | 0.000 |
JITキャッシュなし (型指定あり) |
0.326 | 0.001 | 1.191 | 0.000 | 0.000 |
JITキャッシュ使用 (型指定あり) |
0.329 | 0.001 | 0.078 | 0.000 | 0.000 |
AOT | 0.000 | 0.000 | 0.083 | 0.000 | 0.000 |
- 型指定の有無でコンパイルのタイミングが変わることに注意。
- キャッシュでコンパイル相当部分の時間が大幅に減った。おそらくクラスもキャッシュできたのだろう。
- AOTは構造からして呼出元関数のコンパイル結果にjitclassが同梱されなければコードが完走しないため、AOTでの成功は確定。
キャッシュ、AOTともに80msほど残るのが気になるので、キャッシュとAOTの最低オーバーヘッドも確認することにする。
キャッシュとAOTの最低オーバーヘッド
呼出元関数を、jitclassを呼び出さずただのNumpy配列を使う関数に置き換え比較用として測定した。
元のコードのjitclass呼出元関数
@jitclass(spec)
class Buffer():
メソッド定義...
@jit(int64(int64), nopython=True, cache=True)
def func(x):
a = Buffer() # jitclassの呼出
a.append(x * 10)
v = a.popleft()
return v
比較用の自作クラス不使用関数
''' クラス定義をコメントアウト
@jitclass(spec)
class Buffer():
メソッド定義...
'''
@jit(int64(int64), nopython=True, cache=True)
def func(x):
a = np.zeros(4, dtype=np.int64) # 自作クラスを使わない
a[1] = x * 10
v = a[1]
return v
比較用コード(自作クラス不使用)-JIT (クリックで展開/折りたたみ)
N = 10
from time import perf_counter
measured = []
measured.append(('begin', perf_counter(), 0))
import numpy as np
from numba import jit, int64, typeof
try:
from numba.experimental import jitclass
except ImportError:
from numba import jitclass
measured.append(('import_libraries', perf_counter(), 1))
''' # クラス定義をコメントアウト
spec = [('_top', int64),
('_bottom', int64),
('_arraysize', int64),
('_dtype', typeof(np.int64)),
('_array', int64[:]),]
@jitclass(spec)
class Buffer():
def __init__(self):
self._top = 0
self._bottom = 0
self._arraysize = 4
self._dtype = np.int64
self._array = np.zeros(self._arraysize, dtype=self._dtype)
@property
def nholdings(self):
return self._top - self._bottom
def pop(self):
if self.nholdings <= 0:
raise Exception
ret = self._array[self._ind(self._top - 1)]
self._top -= 1
return ret
def popleft(self):
if self.nholdings <= 0:
raise Exception
ret = self._array[self._ind(self._bottom)]
self._bottom += 1
return ret
def append(self, value):
if self.nholdings >= self._arraysize:
self._extend()
self._array[self._ind(self._top)] = value
self._top += 1
return None
def _extend(self):
oldsize = self._arraysize
new = np.zeros(oldsize * 2, dtype=self._dtype)
if self._ind(self._bottom) < self._ind(self._top):
new[0: oldsize] = self._array
else:
new[self._ind(self._bottom): oldsize] = (
self._array[self._ind(self._bottom): oldsize])
new[oldsize: oldsize + self._ind(self._top)] = (
self._array[0: 0 + self._ind(self._top)])
self._array = new
self._arraysize *= 2
return None
def _ind(self, pos):
return pos % self._arraysize
'''
measured.append(('define_jitclass', perf_counter(), 2))
'''切換え用
@jit(nopython=True) # キャッシュなし(型指定なし)
@jit(int64(int64), nopython=True, cache=False) # キャッシュなし(型指定あり)
@jit(int64(int64), nopython=True, cache=True) # キャッシュあり(型指定あり)
'''
@jit(int64(int64), nopython=True, cache=True) # キャッシュあり(型指定あり)
def func(x):
a = np.zeros(4, dtype=np.int64) # 自作クラスを使わない
a[1] = x * 10
v = a[1]
return v
measured.append(('define_func', perf_counter(), 3))
ans = func(N)
measured.append(('call_func_1st', perf_counter(), ans))
ans = func(N)
measured.append(('call_func_2nd', perf_counter(), ans))
ans = func(N)
measured.append(('call_func_3rd', perf_counter(), ans))
ans = func(N)
measured.append(('call_func_4th', perf_counter(), ans))
for i in range(len(measured)):
if i == 0:
continue
elapsedtime = str(round(measured[i][1] - measured[i - 1][1], 3))
print(elapsedtime, measured[i][0], measured[i][2], sep='\t')
i = 4
elapsedtime = str(round(measured[i][1] - measured[0][1], 3))
print(elapsedtime, 'from-begin to-' + measured[i][0], sep='\t')
elapsedtime = str(round(measured[-1][1] - measured[0][1], 3))
print(elapsedtime, 'total', sep='\t')
--- 展開ここまで ---
比較用コード(自作クラス不使用)-AOT (クリックで展開/折りたたみ)
N = 10
from time import perf_counter
measured = []
measured.append(('begin', perf_counter(), 0))
pass # import numpy as np # numpy のインポートもスキップ
measured.append(('import_libraries', perf_counter(), 1))
measured.append(('define_jitclass', perf_counter(), 2))
try:
from my_module import func
except ImportError:
print('---compiling...---')
import numpy as np
import numba
from numba import int64, typeof
try:
from numba.experimental import jitclass
except ImportError:
from numba import jitclass
from numba.pycc import CC
'''
spec = [('_top', int64),
('_bottom', int64),
('_arraysize', int64),
('_dtype', typeof(np.int64)),
('_array', int64[:]),]
@jitclass(spec)
class Buffer():
def __init__(self):
self._top = 0
self._bottom = 0
self._arraysize = 4
self._dtype = np.int64
self._array = np.zeros(self._arraysize, dtype=self._dtype)
@property
def nholdings(self):
return self._top - self._bottom
def pop(self):
if self.nholdings <= 0:
raise Exception
ret = self._array[self._ind(self._top - 1)]
self._top -= 1
return ret
def popleft(self):
if self.nholdings <= 0:
raise Exception
ret = self._array[self._ind(self._bottom)]
self._bottom += 1
return ret
def append(self, value):
if self.nholdings >= self._arraysize:
self._extend()
self._array[self._ind(self._top)] = value
self._top += 1
return None
def _extend(self):
oldsize = self._arraysize
new = np.zeros(oldsize * 2, dtype=self._dtype)
if self._ind(self._bottom) < self._ind(self._top):
new[0: oldsize] = self._array
else:
new[self._ind(self._bottom): oldsize] = (
self._array[self._ind(self._bottom): oldsize])
new[oldsize: oldsize + self._ind(self._top)] = (
self._array[0: 0 + self._ind(self._top)])
self._array = new
self._arraysize *= 2
return None
def _ind(self, pos):
return pos % self._arraysize
'''
def func(x):
a = np.zeros(4, dtype=np.int64) # 自作クラスを使わない
a[1] = x * 10
v = a[1]
return v
cc = CC('my_module')
cc.export('func', 'int64(int64)')(func)
cc.compile()
print('---compile end---')
exit(0)
else:
pass
measured.append(('define_func', perf_counter(), 3))
ans = func(N)
measured.append(('call_func_1st', perf_counter(), ans))
ans = func(N)
measured.append(('call_func_2nd', perf_counter(), ans))
ans = func(N)
measured.append(('call_func_3rd', perf_counter(), ans))
ans = func(N)
measured.append(('call_func_4th', perf_counter(), ans))
for i in range(len(measured)):
if i == 0:
continue
elapsedtime = str(round(measured[i][1] - measured[i - 1][1], 3))
print(elapsedtime, measured[i][0], measured[i][2], sep='\t')
i = 4
elapsedtime = str(round(measured[i][1] - measured[0][1], 3))
print(elapsedtime, 'from-begin to-' + measured[i][0], sep='\t')
elapsedtime = str(round(measured[-1][1] - measured[0][1], 3))
print(elapsedtime, 'total', sep='\t')
--- 展開ここまで ---
測定結果
(単位: sec)
コード | ライブラリ インポート |
クラス定義 | 関数定義 (AOTは読込) |
初回呼出 | 2回目呼出 |
---|---|---|---|---|---|
jitclass呼出、キャッシュなし (型指定あり) |
0.326 | 0.001 | 1.191 | 0.000 | 0.000 |
クラス不使用、キャッシュなし (型指定あり) |
0.311 | 0.000 | 0.107 | 0.000 | 0.000 |
jitclass呼出、キャッシュ使用 (型指定あり) |
0.329 | 0.001 | 0.078 | 0.000 | 0.000 |
クラス不使用、キャッシュ使用 (型指定あり) |
0.309 | 0.000 | 0.075 | 0.000 | 0.000 |
jitclass呼出、AOT | 0.000 | 0.000 | 0.083 | 0.000 | 0.000 |
クラス不使用、AOT | 0.000 | 0.000 | 0.082 | 0.000 | 0.000 |
自作クラス不使用の結果を見るとキャッシュやAOTでも関数部分のコンパイル結果の読込だけで80msほどかかることが分かる。
それをjitclass使用の結果から差し引くと、キャッシュ/AOTでクラス分のコンパイル時間もゼロになったとみなしてよさそう。
追試: クラスのコンパイルを重くしてみる
jitclassのメソッドに無駄な処理を追加し、クラスのコンパイルを重くしてみた。
np.sortがあるとコンパイルに時間がかかるらしい (参考) ので、使用されない分岐先をif文で作り追加した。
処理追加コード-JIT (クリックで展開/折りたたみ)
N = 10
from time import perf_counter
measured = []
measured.append(('begin', perf_counter(), 0))
import numpy as np
from numba import jit, int64, typeof
try:
from numba.experimental import jitclass
except ImportError:
from numba import jitclass
measured.append(('import_libraries', perf_counter(), 1))
spec = [('_top', int64),
('_bottom', int64),
('_arraysize', int64),
('_dtype', typeof(np.int64)),
('_array', int64[:]),]
@jitclass(spec)
class Buffer():
def __init__(self):
self._top = 0
self._bottom = 0
self._arraysize = 4
self._dtype = np.int64
self._array = np.zeros(self._arraysize, dtype=self._dtype)
@property
def nholdings(self):
return self._top - self._bottom
def pop(self):
if self.nholdings <= 0:
raise Exception
ret = self._array[self._ind(self._top - 1)]
self._top -= 1
return ret
def popleft(self):
if self.nholdings <= 0:
raise Exception
ret = self._array[self._ind(self._bottom)]
self._bottom += 1
return ret
def append(self, value):
if self.nholdings >= self._arraysize:
self._extend()
self._array[self._ind(self._top)] = value
self._top += 1
return None
def _extend(self):
if self.nholdings < -9999:
self._array = np.sort(self._array) # コンパイルを遅くする
oldsize = self._arraysize
new = np.zeros(oldsize * 2, dtype=self._dtype)
if self._ind(self._bottom) < self._ind(self._top):
new[0: oldsize] = self._array
else:
new[self._ind(self._bottom): oldsize] = (
self._array[self._ind(self._bottom): oldsize])
new[oldsize: oldsize + self._ind(self._top)] = (
self._array[0: 0 + self._ind(self._top)])
self._array = new
self._arraysize *= 2
return None
def _ind(self, pos):
return pos % self._arraysize
measured.append(('define_jitclass', perf_counter(), 2))
'''切換え用
@jit(nopython=True) # キャッシュなし(型指定なし)
@jit(int64(int64), nopython=True, cache=False) # キャッシュなし(型指定あり)
@jit(int64(int64), nopython=True, cache=True) # キャッシュあり(型指定あり)
'''
@jit(int64(int64), nopython=True, cache=True) # キャッシュあり(型指定あり)
def func(x):
a = Buffer()
a.append(x * 10)
v = a.popleft()
return v
measured.append(('define_func', perf_counter(), 3))
ans = func(N)
measured.append(('call_func_1st', perf_counter(), ans))
ans = func(N)
measured.append(('call_func_2nd', perf_counter(), ans))
ans = func(N)
measured.append(('call_func_3rd', perf_counter(), ans))
ans = func(N)
measured.append(('call_func_4th', perf_counter(), ans))
for i in range(len(measured)):
if i == 0:
continue
elapsedtime = str(round(measured[i][1] - measured[i - 1][1], 3))
print(elapsedtime, measured[i][0], measured[i][2], sep='\t')
i = 4
elapsedtime = str(round(measured[i][1] - measured[0][1], 3))
print(elapsedtime, 'from-begin to-' + measured[i][0], sep='\t')
elapsedtime = str(round(measured[-1][1] - measured[0][1], 3))
print(elapsedtime, 'total', sep='\t')
--- 展開ここまで ---
測定結果
(単位: 秒)
コード | ライブラリ インポート |
クラス定義 | 関数定義 (AOTは読込) |
初回呼出 | 2回目呼出 |
---|---|---|---|---|---|
元のコード、キャッシュなし (型指定あり) |
0.326 | 0.001 | 1.191 | 0.000 | 0.000 |
追加後、キャッシュなし (型指定あり) |
0.327 | 0.001 | 2.090 | 0.000 | 0.000 |
元のコード、キャッシュ使用 (型指定あり) |
0.329 | 0.001 | 0.078 | 0.000 | 0.000 |
追加後、キャッシュ使用 (型指定あり) |
0.325 | 0.001 | 0.079 | 0.000 | 0.000 |
コード追加でキャッシュなしは900ms増加なのに対してキャッシュありは変化なし。jitclass分のキャッシュに成功していると確信してよいだろう。
注意: キャッシュを利用したい場合はjitclassの呼出をキャッシュ済みNumba化関数から行うこと
実行時にjitclassを素のPythonや非キャッシュのNumba化関数から呼び出してしまうとキャッシュが利用されない。jitclassのキャッシュは呼出元のNumba化関数のキャッシュに一体化されるようだ。
@jitclass(spec)
class Buffer():
''' メソッド...'''
@jit(nopython=True, cache=True)
def nbfunc(x):
a = Buffer() # jitclassをNumba化関数から呼ぶと、
''' 処理... ''' # クラスのコンパイル結果が関数のキャッシュに同梱・利用される
nbfunc()
def pyfunc(x):
a = Buffer() # jitclassを素Pythonから呼ぶと再コンパイルになる
''' 処理... '''
pyfunc()
@jitclass(spec)
class Buffer():
''' メソッド...'''
@jit(nopython=True, cache=True)
def nmbfunc(x):
a = Buffer()
''' 処理... '''
nmbfunc()
@jit(nopython=True)
def nmbfunc2():
a = Buffer()
''' 処理... '''
nmbfunc2() #それどころか別のNumba化関数で非キャッシュのものから呼び出しても再コンパイルになる
おわり
Numba化された関数から@jitclass
されたクラスを呼び出す構造で元の関数をキャッシュ/AOTすると、関数のコンパイル結果にjitclassのコンパイル結果も含められてキャッシュ/AOTされることを確かめた。
関数のコンパイル時間とクラスのコンパイル時間の分離ができないのでキャッシュ成否の確証を得るのに手間取った。もっと大規模なクラスでテストした方が楽だったかもしれない。
Numbaでクラスを使うのは色々大変だけど可能性が広がった。