LoginSignup
2
0

More than 1 year has passed since last update.

Numbaでjitclassのコンパイル結果をキャッシュやAOT (事前) コンパイルできることを確認した

Last updated at Posted at 2022-05-14

はじめに

Pythonを高速化するNumbaにはクラスをNumba化するjitclassがあるが、jitclassはキャッシュやAOT (事前) コンパイルができないと思っていた。

そんな中、jitclassもキャッシュできるという情報を発見。
https://github.com/numba/numba/issues/4830#issuecomment-862424725

それどころかAOTもできるという情報も見つけた。
https://github.com/riantkb/typical90_python の L: 012、Q: 017

なにやら、Numba化された関数から@jitclassしたクラスを呼び出す構造にしておいて、元の関数をキャッシュもしくはAOTするとクラスの内容も保存されるとか。

本記事はjitclassのキャッシュ/AOTが可能か検証しただけだが、日本語どころか公式 (jitclassAOT1AOT2) にも情報が乏しいので共有しておく。

目次

実験用コード

サンプルとしてはリングバッファのようなものを作った。
手元 (Numba 0.53) とAtCoderのコードテスト (Numba 0.48) どちらでも動くようにしてある。本記事の測定結果はコードテストでのもの。
クラスのコンパイル時間と呼出元関数のそれとを分離する方法が思いつかなかったので、呼出元関数は極力短くした。コンパイル時間が分かりやすいようにJITでも型指定をつけ、計算は最小限の内容とした。

AtCoderにおけるNumba化関数のキャッシュとAOTについてはyniji氏の記事 「AtCoderで Python を高速化する Numpy + Numba を使う」を参照。

測定用コード-JIT用

N = 10


from time import perf_counter
measured = []
measured.append(('begin', perf_counter(), 0))

import numpy as np
from numba import jit, int64, typeof
try:
    from numba.experimental import jitclass
except ImportError:
    from numba import jitclass
measured.append(('import_libraries', perf_counter(), 1))


spec = [('_top', int64),
        ('_bottom', int64),
        ('_arraysize', int64),
        ('_dtype', typeof(np.int64)),
        ('_array', int64[:]),]
@jitclass(spec)
class Buffer():
    def __init__(self):
        self._top = 0
        self._bottom = 0
        self._arraysize = 4
        self._dtype = np.int64
        self._array = np.zeros(self._arraysize, dtype=self._dtype)

    @property
    def nholdings(self):
        return self._top - self._bottom

    def pop(self):
        if self.nholdings <= 0:
            raise Exception
        ret = self._array[self._ind(self._top - 1)]
        self._top -= 1
        return ret

    def popleft(self):
        if self.nholdings <= 0:
            raise Exception
        ret = self._array[self._ind(self._bottom)]
        self._bottom += 1
        return ret

    def append(self, value):
        if self.nholdings >= self._arraysize:
            self._extend()
        self._array[self._ind(self._top)] = value
        self._top += 1
        return None

    def _extend(self):
        oldsize = self._arraysize
        new = np.zeros(oldsize * 2, dtype=self._dtype)
        if self._ind(self._bottom) < self._ind(self._top):
            new[0: oldsize] = self._array
        else:
            new[self._ind(self._bottom): oldsize] = (
                self._array[self._ind(self._bottom): oldsize])
            new[oldsize: oldsize + self._ind(self._top)] = (
                self._array[0: 0 + self._ind(self._top)])
        self._array = new
        self._arraysize *= 2
        return None

    def _ind(self, pos):
        return pos % self._arraysize

measured.append(('define_jitclass', perf_counter(), 2))

'''切換え用
@jit(nopython=True)    # キャッシュなし(型指定なし)
@jit(int64(int64), nopython=True, cache=False)    # キャッシュなし(型指定あり)
@jit(int64(int64), nopython=True, cache=True)    # キャッシュあり(型指定あり)
'''

@jit(int64(int64), nopython=True, cache=True)    # キャッシュあり(型指定あり)
def func(x):
    a = Buffer()
    a.append(x * 10)
    v = a.popleft()
    return v

measured.append(('define_func', perf_counter(), 3))


ans = func(N)
measured.append(('call_func_1st', perf_counter(), ans))
ans = func(N)
measured.append(('call_func_2nd', perf_counter(), ans))
ans = func(N)
measured.append(('call_func_3rd', perf_counter(), ans))
ans = func(N)
measured.append(('call_func_4th', perf_counter(), ans))

for i in range(len(measured)):
    if i == 0:
        continue
    elapsedtime = str(round(measured[i][1] - measured[i - 1][1], 3))
    print(elapsedtime, measured[i][0], measured[i][2], sep='\t')
i = 4
elapsedtime = str(round(measured[i][1] - measured[0][1], 3))
print(elapsedtime, 'from-begin to-' + measured[i][0], sep='\t')
elapsedtime = str(round(measured[-1][1] - measured[0][1], 3))
print(elapsedtime, 'total', sep='\t')

測定用コード-AOT用

N = 10


from time import perf_counter
measured = []
measured.append(('begin', perf_counter(), 0))

pass # import numpy as np # numpy のインポートもスキップ

measured.append(('import_libraries', perf_counter(), 1))
measured.append(('define_jitclass', perf_counter(), 2))

try:
    from my_module import func
except ImportError:

    print('---compiling...---')

    import numpy as np
    import numba
    from numba import int64, typeof

    try:
        from numba.experimental import jitclass
    except ImportError:
        from numba import jitclass
    from numba.pycc import CC

    spec = [('_top', int64),
            ('_bottom', int64),
            ('_arraysize', int64),
            ('_dtype', typeof(np.int64)),
            ('_array', int64[:]),]
    @jitclass(spec)
    class Buffer():
        def __init__(self):
            self._top = 0
            self._bottom = 0
            self._arraysize = 4
            self._dtype = np.int64
            self._array = np.zeros(self._arraysize, dtype=self._dtype)

        @property
        def nholdings(self):
            return self._top - self._bottom

        def pop(self):
            if self.nholdings <= 0:
                raise Exception
            ret = self._array[self._ind(self._top - 1)]
            self._top -= 1
            return ret

        def popleft(self):
            if self.nholdings <= 0:
                raise Exception
            ret = self._array[self._ind(self._bottom)]
            self._bottom += 1
            return ret

        def append(self, value):
            if self.nholdings >= self._arraysize:
                self._extend()
            self._array[self._ind(self._top)] = value
            self._top += 1
            return None

        def _extend(self):
            oldsize = self._arraysize
            new = np.zeros(oldsize * 2, dtype=self._dtype)
            if self._ind(self._bottom) < self._ind(self._top):
                new[0: oldsize] = self._array
            else:
                new[self._ind(self._bottom): oldsize] = (
                    self._array[self._ind(self._bottom): oldsize])
                new[oldsize: oldsize + self._ind(self._top)] = (
                    self._array[0: 0 + self._ind(self._top)])
            self._array = new
            self._arraysize *= 2
            return None

        def _ind(self, pos):
            return pos % self._arraysize


    def func(x):    # AOTでは@jitを使わない
        a = Buffer()
        a.append(x * 10)
        v = a.popleft()
        return v


    cc = CC('my_module')
    cc.export('func', 'int64(int64)')(func)
    cc.compile()
    print('---compile end---')
    exit(0)
else:
    pass

measured.append(('define_func', perf_counter(), 3))


ans = func(N)
measured.append(('call_func_1st', perf_counter(), ans))
ans = func(N)
measured.append(('call_func_2nd', perf_counter(), ans))
ans = func(N)
measured.append(('call_func_3rd', perf_counter(), ans))
ans = func(N)
measured.append(('call_func_4th', perf_counter(), ans))

for i in range(len(measured)):
    if i == 0:
        continue
    elapsedtime = str(round(measured[i][1] - measured[i - 1][1], 3))
    print(elapsedtime, measured[i][0], measured[i][2], sep='\t')
i = 4
elapsedtime = str(round(measured[i][1] - measured[0][1], 3))
print(elapsedtime, 'from-begin to-' + measured[i][0], sep='\t')
elapsedtime = str(round(measured[-1][1] - measured[0][1], 3))
print(elapsedtime, 'total', sep='\t')

測定結果

(単位: sec)

コード ライブラリ
インポート
クラス定義 関数定義
(AOTは読込)
初回呼出 2回目呼出
JITキャッシュなし
(型指定なし)
0.325 0.001 0.000 1.187 0.000
JITキャッシュなし
(型指定あり)
0.326 0.001 1.191 0.000 0.000
JITキャッシュ使用
(型指定あり)
0.329 0.001 0.078 0.000 0.000
AOT 0.000 0.000 0.083 0.000 0.000
  • 型指定の有無でコンパイルのタイミングが変わることに注意。
  • キャッシュでコンパイル相当部分の時間が大幅に減った。おそらくクラスもキャッシュできたのだろう。
  • AOTは構造からして呼出元関数のコンパイル結果にjitclassが同梱されなければコードが完走しないため、AOTでの成功は確定。

キャッシュ、AOTともに80msほど残るのが気になるので、キャッシュとAOTの最低オーバーヘッドも確認することにする。

キャッシュとAOTの最低オーバーヘッド

呼出元関数を、jitclassを呼び出さずただのNumpy配列を使う関数に置き換え比較用として測定した。

元のコードのjitclass呼出元関数

@jitclass(spec)
class Buffer():
    メソッド定義...

@jit(int64(int64), nopython=True, cache=True)
def func(x):
    a = Buffer()    # jitclassの呼出
    a.append(x * 10)
    v = a.popleft()
    return v

比較用の自作クラス不使用関数

''' クラス定義をコメントアウト
@jitclass(spec)
class Buffer():
    メソッド定義...
'''
@jit(int64(int64), nopython=True, cache=True)
def func(x):
    a = np.zeros(4, dtype=np.int64)    # 自作クラスを使わない
    a[1] = x * 10
    v = a[1]
    return v
比較用コード(自作クラス不使用)-JIT (クリックで展開/折りたたみ)
N = 10


from time import perf_counter
measured = []
measured.append(('begin', perf_counter(), 0))

import numpy as np
from numba import jit, int64, typeof
try:
    from numba.experimental import jitclass
except ImportError:
    from numba import jitclass
measured.append(('import_libraries', perf_counter(), 1))

''' # クラス定義をコメントアウト
spec = [('_top', int64),
        ('_bottom', int64),
        ('_arraysize', int64),
        ('_dtype', typeof(np.int64)),
        ('_array', int64[:]),]
@jitclass(spec)
class Buffer():
    def __init__(self):
        self._top = 0
        self._bottom = 0
        self._arraysize = 4
        self._dtype = np.int64
        self._array = np.zeros(self._arraysize, dtype=self._dtype)

    @property
    def nholdings(self):
        return self._top - self._bottom

    def pop(self):
        if self.nholdings <= 0:
            raise Exception
        ret = self._array[self._ind(self._top - 1)]
        self._top -= 1
        return ret

    def popleft(self):
        if self.nholdings <= 0:
            raise Exception
        ret = self._array[self._ind(self._bottom)]
        self._bottom += 1
        return ret

    def append(self, value):
        if self.nholdings >= self._arraysize:
            self._extend()
        self._array[self._ind(self._top)] = value
        self._top += 1
        return None

    def _extend(self):
        oldsize = self._arraysize
        new = np.zeros(oldsize * 2, dtype=self._dtype)
        if self._ind(self._bottom) < self._ind(self._top):
            new[0: oldsize] = self._array
        else:
            new[self._ind(self._bottom): oldsize] = (
                self._array[self._ind(self._bottom): oldsize])
            new[oldsize: oldsize + self._ind(self._top)] = (
                self._array[0: 0 + self._ind(self._top)])
        self._array = new
        self._arraysize *= 2
        return None

    def _ind(self, pos):
        return pos % self._arraysize
'''
measured.append(('define_jitclass', perf_counter(), 2))

'''切換え用
@jit(nopython=True)    # キャッシュなし(型指定なし)
@jit(int64(int64), nopython=True, cache=False)    # キャッシュなし(型指定あり)
@jit(int64(int64), nopython=True, cache=True)    # キャッシュあり(型指定あり)
'''

@jit(int64(int64), nopython=True, cache=True)    # キャッシュあり(型指定あり)
def func(x):
    a = np.zeros(4, dtype=np.int64)    # 自作クラスを使わない
    a[1] = x * 10
    v = a[1]
    return v

measured.append(('define_func', perf_counter(), 3))


ans = func(N)
measured.append(('call_func_1st', perf_counter(), ans))
ans = func(N)
measured.append(('call_func_2nd', perf_counter(), ans))
ans = func(N)
measured.append(('call_func_3rd', perf_counter(), ans))
ans = func(N)
measured.append(('call_func_4th', perf_counter(), ans))

for i in range(len(measured)):
    if i == 0:
        continue
    elapsedtime = str(round(measured[i][1] - measured[i - 1][1], 3))
    print(elapsedtime, measured[i][0], measured[i][2], sep='\t')
i = 4
elapsedtime = str(round(measured[i][1] - measured[0][1], 3))
print(elapsedtime, 'from-begin to-' + measured[i][0], sep='\t')
elapsedtime = str(round(measured[-1][1] - measured[0][1], 3))
print(elapsedtime, 'total', sep='\t')

--- 展開ここまで ---

比較用コード(自作クラス不使用)-AOT (クリックで展開/折りたたみ)
N = 10


from time import perf_counter
measured = []
measured.append(('begin', perf_counter(), 0))

pass # import numpy as np # numpy のインポートもスキップ

measured.append(('import_libraries', perf_counter(), 1))
measured.append(('define_jitclass', perf_counter(), 2))

try:
    from my_module import func
except ImportError:

    print('---compiling...---')

    import numpy as np
    import numba
    from numba import int64, typeof

    try:
        from numba.experimental import jitclass
    except ImportError:
        from numba import jitclass
    from numba.pycc import CC

    '''
    spec = [('_top', int64),
            ('_bottom', int64),
            ('_arraysize', int64),
            ('_dtype', typeof(np.int64)),
            ('_array', int64[:]),]
    @jitclass(spec)
    class Buffer():
        def __init__(self):
            self._top = 0
            self._bottom = 0
            self._arraysize = 4
            self._dtype = np.int64
            self._array = np.zeros(self._arraysize, dtype=self._dtype)

        @property
        def nholdings(self):
            return self._top - self._bottom

        def pop(self):
            if self.nholdings <= 0:
                raise Exception
            ret = self._array[self._ind(self._top - 1)]
            self._top -= 1
            return ret

        def popleft(self):
            if self.nholdings <= 0:
                raise Exception
            ret = self._array[self._ind(self._bottom)]
            self._bottom += 1
            return ret

        def append(self, value):
            if self.nholdings >= self._arraysize:
                self._extend()
            self._array[self._ind(self._top)] = value
            self._top += 1
            return None

        def _extend(self):
            oldsize = self._arraysize
            new = np.zeros(oldsize * 2, dtype=self._dtype)
            if self._ind(self._bottom) < self._ind(self._top):
                new[0: oldsize] = self._array
            else:
                new[self._ind(self._bottom): oldsize] = (
                    self._array[self._ind(self._bottom): oldsize])
                new[oldsize: oldsize + self._ind(self._top)] = (
                    self._array[0: 0 + self._ind(self._top)])
            self._array = new
            self._arraysize *= 2
            return None

        def _ind(self, pos):
            return pos % self._arraysize
    '''

    def func(x):
        a = np.zeros(4, dtype=np.int64)    # 自作クラスを使わない
        a[1] = x * 10
        v = a[1]
        return v


    cc = CC('my_module')
    cc.export('func', 'int64(int64)')(func)
    cc.compile()
    print('---compile end---')
    exit(0)
else:
    pass

measured.append(('define_func', perf_counter(), 3))


ans = func(N)
measured.append(('call_func_1st', perf_counter(), ans))
ans = func(N)
measured.append(('call_func_2nd', perf_counter(), ans))
ans = func(N)
measured.append(('call_func_3rd', perf_counter(), ans))
ans = func(N)
measured.append(('call_func_4th', perf_counter(), ans))

for i in range(len(measured)):
    if i == 0:
        continue
    elapsedtime = str(round(measured[i][1] - measured[i - 1][1], 3))
    print(elapsedtime, measured[i][0], measured[i][2], sep='\t')
i = 4
elapsedtime = str(round(measured[i][1] - measured[0][1], 3))
print(elapsedtime, 'from-begin to-' + measured[i][0], sep='\t')
elapsedtime = str(round(measured[-1][1] - measured[0][1], 3))
print(elapsedtime, 'total', sep='\t')

--- 展開ここまで ---

測定結果

(単位: sec)

コード ライブラリ
インポート
クラス定義 関数定義
(AOTは読込)
初回呼出 2回目呼出
jitclass呼出、キャッシュなし
(型指定あり)
0.326 0.001 1.191 0.000 0.000
クラス不使用、キャッシュなし
(型指定あり)
0.311 0.000 0.107 0.000 0.000
jitclass呼出、キャッシュ使用
(型指定あり)
0.329 0.001 0.078 0.000 0.000
クラス不使用、キャッシュ使用
(型指定あり)
0.309 0.000 0.075 0.000 0.000
jitclass呼出、AOT 0.000 0.000 0.083 0.000 0.000
クラス不使用、AOT 0.000 0.000 0.082 0.000 0.000

自作クラス不使用の結果を見るとキャッシュやAOTでも関数部分のコンパイル結果の読込だけで80msほどかかることが分かる。
それをjitclass使用の結果から差し引くと、キャッシュ/AOTでクラス分のコンパイル時間もゼロになったとみなしてよさそう。

追試: クラスのコンパイルを重くしてみる

jitclassのメソッドに無駄な処理を追加し、クラスのコンパイルを重くしてみた。
np.sortがあるとコンパイルに時間がかかるらしい (参考) ので、使用されない分岐先をif文で作り追加した。

処理追加コード-JIT (クリックで展開/折りたたみ)
N = 10


from time import perf_counter
measured = []
measured.append(('begin', perf_counter(), 0))

import numpy as np
from numba import jit, int64, typeof
try:
    from numba.experimental import jitclass
except ImportError:
    from numba import jitclass
measured.append(('import_libraries', perf_counter(), 1))


spec = [('_top', int64),
        ('_bottom', int64),
        ('_arraysize', int64),
        ('_dtype', typeof(np.int64)),
        ('_array', int64[:]),]
@jitclass(spec)
class Buffer():
    def __init__(self):
        self._top = 0
        self._bottom = 0
        self._arraysize = 4
        self._dtype = np.int64
        self._array = np.zeros(self._arraysize, dtype=self._dtype)

    @property
    def nholdings(self):
        return self._top - self._bottom

    def pop(self):
        if self.nholdings <= 0:
            raise Exception
        ret = self._array[self._ind(self._top - 1)]
        self._top -= 1
        return ret

    def popleft(self):
        if self.nholdings <= 0:
            raise Exception
        ret = self._array[self._ind(self._bottom)]
        self._bottom += 1
        return ret

    def append(self, value):
        if self.nholdings >= self._arraysize:
            self._extend()
        self._array[self._ind(self._top)] = value
        self._top += 1
        return None

    def _extend(self):
        if self.nholdings < -9999:
            self._array = np.sort(self._array)    # コンパイルを遅くする
        oldsize = self._arraysize
        new = np.zeros(oldsize * 2, dtype=self._dtype)
        if self._ind(self._bottom) < self._ind(self._top):
            new[0: oldsize] = self._array
        else:
            new[self._ind(self._bottom): oldsize] = (
                self._array[self._ind(self._bottom): oldsize])
            new[oldsize: oldsize + self._ind(self._top)] = (
                self._array[0: 0 + self._ind(self._top)])
        self._array = new
        self._arraysize *= 2
        return None

    def _ind(self, pos):
        return pos % self._arraysize

measured.append(('define_jitclass', perf_counter(), 2))

'''切換え用
@jit(nopython=True)    # キャッシュなし(型指定なし)
@jit(int64(int64), nopython=True, cache=False)    # キャッシュなし(型指定あり)
@jit(int64(int64), nopython=True, cache=True)    # キャッシュあり(型指定あり)
'''

@jit(int64(int64), nopython=True, cache=True)    # キャッシュあり(型指定あり)
def func(x):
    a = Buffer()
    a.append(x * 10)
    v = a.popleft()
    return v

measured.append(('define_func', perf_counter(), 3))


ans = func(N)
measured.append(('call_func_1st', perf_counter(), ans))
ans = func(N)
measured.append(('call_func_2nd', perf_counter(), ans))
ans = func(N)
measured.append(('call_func_3rd', perf_counter(), ans))
ans = func(N)
measured.append(('call_func_4th', perf_counter(), ans))

for i in range(len(measured)):
    if i == 0:
        continue
    elapsedtime = str(round(measured[i][1] - measured[i - 1][1], 3))
    print(elapsedtime, measured[i][0], measured[i][2], sep='\t')
i = 4
elapsedtime = str(round(measured[i][1] - measured[0][1], 3))
print(elapsedtime, 'from-begin to-' + measured[i][0], sep='\t')
elapsedtime = str(round(measured[-1][1] - measured[0][1], 3))
print(elapsedtime, 'total', sep='\t')

--- 展開ここまで ---

測定結果

(単位: 秒)

コード ライブラリ
インポート
クラス定義 関数定義
(AOTは読込)
初回呼出 2回目呼出
元のコード、キャッシュなし
(型指定あり)
0.326 0.001 1.191 0.000 0.000
追加後、キャッシュなし
(型指定あり)
0.327 0.001 2.090 0.000 0.000
元のコード、キャッシュ使用
(型指定あり)
0.329 0.001 0.078 0.000 0.000
追加後、キャッシュ使用
(型指定あり)
0.325 0.001 0.079 0.000 0.000

コード追加でキャッシュなしは900ms増加なのに対してキャッシュありは変化なし。jitclass分のキャッシュに成功していると確信してよいだろう。

注意: キャッシュを利用したい場合はjitclassの呼出をキャッシュ済みNumba化関数から行うこと

実行時にjitclassを素のPythonや非キャッシュのNumba化関数から呼び出してしまうとキャッシュが利用されない。jitclassのキャッシュは呼出元のNumba化関数のキャッシュに一体化されるようだ。

@jitclass(spec)
class Buffer():
    ''' メソッド...'''

@jit(nopython=True, cache=True)
def nbfunc(x):
    a = Buffer()    # jitclassをNumba化関数から呼ぶと、
    ''' 処理... ''' # クラスのコンパイル結果が関数のキャッシュに同梱・利用される
nbfunc()


def pyfunc(x):
    a = Buffer()    # jitclassを素Pythonから呼ぶと再コンパイルになる
    ''' 処理... '''
pyfunc()
@jitclass(spec)
class Buffer():
    ''' メソッド...'''

@jit(nopython=True, cache=True)
def nmbfunc(x):
    a = Buffer()    
    ''' 処理... '''
nmbfunc()

@jit(nopython=True)    
def nmbfunc2():
    a = Buffer()
    ''' 処理... '''
nmbfunc2()    #それどころか別のNumba化関数で非キャッシュのものから呼び出しても再コンパイルになる

おわり

Numba化された関数から@jitclassされたクラスを呼び出す構造で元の関数をキャッシュ/AOTすると、関数のコンパイル結果にjitclassのコンパイル結果も含められてキャッシュ/AOTされることを確かめた。

関数のコンパイル時間とクラスのコンパイル時間の分離ができないのでキャッシュ成否の確証を得るのに手間取った。もっと大規模なクラスでテストした方が楽だったかもしれない。

Numbaでクラスを使うのは色々大変だけど可能性が広がった。

2
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0