40
32

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Python を Numba で高速化するときの間違えやすいポイントまとめ

Last updated at Posted at 2022-05-08

はじめに

決まればPythonを劇的に速くするNumba。
使ってみて細かなはまりポイントがあったので注意点を集約してみた。

変更履歴

(2023-12-20) heapqを使うときの注意点を追加
(2023-12-20) 再帰関数をキャッシュや事前コンパイルできたので言及を訂正
(2023-12-20) 「Numba内から外へアクセス」について各項目を見出しに昇格
(2022-05-13) Numpyのdtype指定の注意点を追加
(2022-05-12) クラス関係としてjitclassへの言及を追記
(2022-05-11) 戻り値を統一する必要があること、例外が使えることを追加
(2022-05-08) 投稿

要約を兼ねた目次

1. ざっくり使い方

基本は from numba import jit して対象の関数に @jit(nopython=True) するだけ。公式のお手本が分かりやすい。

from numba import jit
import numpy as np

x = np.arange(100).reshape(10, 10)

@jit(nopython=True) # Set "nopython" mode for best performance, equivalent to @njit
def go_fast(a): # Function is compiled to machine code when called the first time
    trace = 0.0
    for i in range(a.shape[0]):   # Numba likes loops
        trace += np.tanh(a[i, i]) # Numba likes NumPy functions
    return a + trace              # Numba likes NumPy broadcasting

print(go_fast(x))

https://numba.readthedocs.io/en/stable/user/5minguide.html#will-numba-work-for-my-code

うまく決まればC言語と争うほどの速度になるそうだ。(yniji氏による測定: Python を高速化する Numba, Cython 等を使って Julia Micro-Benchmarks してみた)

1-1. 基本的な制限

対応していない機能が多いので、変数型・関数・文法などすべての要素で注意が必要。

  • 基本的に、数値型 (真偽値含む)・NumPy配列・それらのタプル以外は扱わないのが無難。
    • 使える関数については、数値型の組み込み関数math関数NumPy関数ならだいたいは対応している模様。
    • print関数は使え、複数引数にも対応している。sep引数、end引数は使えない。
  • listはある程度動くが非対応の動作が多くてトラブルになりやすく、速度も不利。Numpy配列が向かない目的に最小限で。
  • dictの使い方もかなり限定される。
  • クラス関係の扱いは著しく限定される。自作クラスをNumba化(@jitclass)することもできるが特別な書き方を必要とし機能制限も多い。
  • 文字列は一応受け入れるが遅くなりやすい。Numba関数には触れさせず内部でも一切使わないのが無難。
  • 各種文法についても初心者にみせるくらいの気持ちで平易にした方が安全。
  • 関数内にreturn文が複数ある場合は返り値の型を揃える。統一できない場合にはError送出を使える。

素のPythonからかなり機能が限定される一方、コードのうちよく効く部分のみに簡単にかけられるのが大きな強み。

2. @jitの引数

基本的に @jit(nopython=True, cache=True) でよい。
@njit(cache=True) でも同じ。

2-1. 必ずnopythonモードで利用する

これは挙げている記事も多い。@jit(nopython=True) で動かないコードは @jit() で動かしてもかなり遅くなる。非対応の動作を @jit に入れてしまっているはずなので切り分けて確認しよう。

なお、 @njit()@jit(nopython=True) と同じ。

2-2. コンパイル時間対策にキャッシュが有効

@jit(nopython=True) ではプログラム終了時にコンパイル結果が破棄され、再実行の度にコンパイル時間がかかってしまう。 @jit(nopython=True, cache=True) とするとコンパイル結果がキャッシュファイルに保存され、プログラムを再実行したときのコンパイル時間がなくなる。長い処理では当然有効だが、短い処理でもコンパイルのオーバーヘッドと高速化とを天秤にかけて@jitするか迷う必要がなくなる。実用上かなりおすすめ。

2-3. 通常の使用に型指定は不要

詳しくは4. 実行時間を測るときかなり癖があるで説明するが、型指定しても指定なしより速度は向上しない。公式の説明でも型指定なしが推奨されている。既定の動作では@jitされた関数が呼び出されたとき、つまりPythonから引数を渡されたときにコンパイルを行うため、型指定なしでもNumbaはコンパイル開始時に引数型を把握できる。

2-4. その他の引数

公式のヒントでさらなる高速化の方法が説明されている。parallel=Trueで並列処理 (GILなし)、fastmath=Trueで数値計算での制約を緩和 (?) など。

3. 基本的にNumba内から外へアクセスしない

3-1. グローバル変数の読み取りや変更は危険

  • 外部の変数を扱うとエラーになったり過去の値で計算したりトラブルの元

コード例 - 外部変数の書き換え

import numpy as np
from numba import jit

outerVar = np.array([4, 5, 6, 7 ])

@jit(nopython=True, cache=True)
def nmbfunc(a):
    outerVar[2] = a    # Numba外へのアクセスはだめ
    return

nmbfunc(8)  # エラー!
print(outerVar)

コード例 - 外部変数の読み取り

from numba import jit

outerVar = 100

@jit(nopython=True, cache=True)
def nmbfunc(a):
    r = a + outerVar   # Numba外へのアクセスはだめ
    return r

print(nmbfunc(1))    # 101 一見正常動作だが

outerVar = 200

print(nmbfunc(1))    # 101 誤った出力! 変更が反映されていない

変数は全て引数で渡すか、

import numpy as np
from numba import jit

outerVar = np.array([4, 5, 6, 7 ])

@jit(nopython=True, cache=True)
def nmbfunc(a, var):
    var[2] = a
    return

nmbfunc(10, outerVar)    # 引数で渡せば動く
print(outerVar)    # [ 4  5 10  7]

もしくは全体を大きな関数で括り外側関数に @jit する

import numpy as np
from numba import jit

@jit(nopython=True, cache=True)
def nestfunc():

    outerVar = np.array([4, 5, 6, 7 ])

    def innerfunc(a):
        outerVar[2] = a   # Numba領域内であれば関数外へのアクセスは可能
        return

    innerfunc(10)
    print(outerVar)    # [ 4  5 10  7]

nestfunc()

3-2. Numba内からの関数呼出しはNumba対応のものに限定

Numba関数内から呼出し可能なのは、原則、Numbaライブラリに準備されている関数 (主に組み込み関数、数値系の標準ライブラリ、Numpy関数) か、自分で @jit してNumba化した関数。

  • @jitした関数から、外の他の@jitのない関数を呼び出すことも不可 (エラー)
from numba import jit

def anotherpyfunc(a):
    return a * 10

@jit(nopython=True, cache=True)
def nmbfunc(n):
    r = anotherpyfunc(n)    # これだけでエラー
    return r

print(nmbfunc(3))
  • @jitした関数から別の@jitした関数を呼び出す事は可能
from numba import jit

@jit(nopython=True, cache=True)
def anothernmbfunc(a):
    return a * 10

@jit(nopython=True, cache=True)
def nmbfunc(n):
    r = anothernmbfunc(n)    # Numba関数同士は呼び出せる
    return r

print(nmbfunc(3))    #30

3-3. 再帰関数は独自の制限がある

「全体を大きな関数で括り外側関数に@jitする」が使えない、 「コンパイル結果のキャッシュや事前コンパイルがつかえない」 など。
外へのアクセスの一種とみなされるらしい。参考: メモ化再帰DPでTLEを避けるには - 西尾泰和のScrapbox
(2023-12-20 変更) 再帰関数もキャッシュできたし、別のNumba関数を経由することで事前コンパイルもできました。

3-4. クラスやメソッドもNumba対応のものに限定

クラスの変数やメソッドはこれらに加えてクラス関係の制約が加わる。
インスタンス生成やメソッド呼出しが可能なのはやはりNumbaライブラリに準備されているものに限られる。(クラスは基本的に組み込み型の一部とnumpy.ndarray。メソッドはそれらのクラスのものの一部)
クラス自体をNumba化 (@jitclass) させれば呼出可能だが、型指定が必要なほか機能制限も多い。

4. 実行時間を測るときかなり癖がある

4-1. @jitで型指定しても速度は向上しない、しない方が速いかもしれない

測定の方法が不適切で「型指定したら速くなった」と誤解している例をよく見かける。以下のように各処理にかかる時間をすべて可視化すると実態が見えてくる。

計測コード (クリックで展開/折りたたみ)
measelapsedtime.py
# 元データの準備
N = int(input())

import numpy as np
np.random.seed(0)
arrOrigin = np.random.randint(low=0, high=1000, size=N, dtype=np.int64)
arr1 = arrOrigin.copy()
arr2 = arrOrigin.copy()
arr3 = arrOrigin.copy()
arr4 = arrOrigin.copy()
arr_i1 = arrOrigin.copy()
arr_i2 = arrOrigin.copy()
arr_f1 = arrOrigin.copy()
arr_f2 = arrOrigin.copy()


# 計測開始 (time.time()よりtime.perf_counter()の方が正確らしい)
from time import perf_counter

measured = []
measured.append(('begin', perf_counter(), 0))    # [0]

from numba import jit
measured.append(('import_Numba', perf_counter(), 0))    # [1]

@jit(nopython=True)    # 型指定なし用
# @jit('f8(i8[:], i8, f8)', nopython=True)    # 型指定あり用
def nmbfunc(arr, param_i, param_f):
    arr.sort()
    s = 0
    for i in range(arr.size):
        if i < arr.size//2:
            s -= arr[i]
        else:
            s += arr[i]
    return s * 1000 + param_i + param_f
measured.append(('define_function', perf_counter(), 0))    # [2]


ans = nmbfunc(arr1, 1, 0.5)
measured.append(('1st_call', perf_counter(), ans))    # [3]

ans = nmbfunc(arr2, 2, 0.5)
measured.append(('2nd_call', perf_counter(), ans))

ans = nmbfunc(arr3, 3, 0.5)
measured.append(('3rd_call', perf_counter(), ans))

ans = nmbfunc(arr4, 4, 0.5)
measured.append(('4th_call', perf_counter(), ans))

try:
    ans = nmbfunc(arr_i1, 1, 2)
    measured.append(('args:(i to f)-1st', perf_counter(), ans))

    ans = nmbfunc(arr_i2, 2, 2)
    measured.append(('args:(i to f)-2nd', perf_counter(), ans))
except Exception as err:
    measured.append(('args:(i to f):', perf_counter(), type(err)))

try:
    ans = nmbfunc(arr_f1, 1.5, 0.5)
    measured.append(('args:(f to i)-1st', perf_counter(), ans))

    ans = nmbfunc(arr_f2, 2.5, 0.5)
    measured.append(('args:(f to i)-2nd', perf_counter(), ans))
except Exception as err:
    measured.append(('args:(f to i):', perf_counter(), type(err)))


print('elapsedtime:')

for i in range(len(measured)):
    if i ==0:
        continue
    elapsedtime = str(round(measured[i][1] - measured[i-1][1], 3)) + '\t'
    print(measured[i][0] + '\t', elapsedtime, measured[i][2])

print('from begin to ' + measured[3][0] + ' end \t', measured[3][1] - measured[0][1])
print('total' + '\t', perf_counter() - measured[0][1])

測定結果 N = 10 ** 6 (単位: sec)

指定 型指定なし 'f8(i8[:], i8, f8)'
インポート 0.268 0.264
関数定義 0.000 0.758
初回呼出 0.811 0.068
開始から初回呼出完了まで 1.079 1.089
呼出2回目 0.057 0.067
呼出3回目 0.057 0.067
呼出4回目 0.057 0.067
別の型で呼出
float引数にint値 初回
0.210 0.068
float引数にint値 2回目 0.057 0.068
別の型で呼出
int引数にfloat値 初回
0.208 0.068
int引数にfloat値 2回目 0.057 0.068
  • 初回呼出の時間を比べると型指定で非常に速くなるように見えるが、そこまでを含めた合計はほぼ同じ。

  • 要するに型指定ありだと関数定義時にコンパイルされ、型指定なしだと新しい引数の型の組合せで呼出した時にコンパイルされる模様。初回呼出の時間はコンパイル時間を含んだり含まなかったりややこしいので呼出2回目以降で評価するのがいい。

  • 型指定なし関数にそれまでと異なる型の引数を渡すと再度コンパイルが走るが、初回よりはるかに少ない時間で済むことがある。必要な部分に限定して再コンパイルする機能でもあるのだろうか。

  • いずれにしても実際に使うときはcache=Trueが有効。

測定結果 N = 10 ** 7 (単位: sec)

指定 型指定なし 'f8(i8[:], i8, f8)'
インポート 0.399 0.281
関数定義 0.000 0.738
初回呼出 1.361 0.710
開始から初回呼出完了まで 1.748 1.742
呼出2回目 0.621 0.712
呼出3回目 0.625 0.719
呼出4回目 0.623 0.724
別の型で呼出
float引数にint値 初回
0.776 0.725
float引数にint値 2回目 0.626 0.727
別の型で呼出
int引数にfloat値 初回
0.773 0.726
int引数にfloat値 2回目 0.624 0.726
  • 実際の計算時間は型指定なしの方がわずかに速い模様。 計算内容や型指定記述などを変えて色々試したものの同様の傾向であり、原因や型指定ありで同じ速度にする方法は分からなかった。知っている人がいたら教えて下さい。

以上より、普通の使い方での@jitによる高速化に型指定は不要、むしろしない方が速いかもしれないといえる。また、公式の高速化ヒントを見ても高速化のために型指定を推奨するような記述は見られない。

4-2. 競技プログラミング (AtCoder) の場合は対策がほしい

許される実行時間が短いのでコンパイルを評価時間外に行いたい。事前キャッシュ作成やAOTコンパイルで対策できる。yniji氏の記事 「AtCoderで Python を高速化する Numpy + Numba を使う」 が詳しい。

勘違いしやすいが、事前キャッシュ作成に対するAOTの利点はNumbaライブラリの読込時間を省略できることが主。Numbaライブラリ読込は数百msかかる上、どうやらばらつきが大きそう。

5. 型指定の記述方法

上述のとおり型指定はまず不要だが、場合により型指定したい場合もある。ここも一見単純ながらはまりポイントがある。

5-1. 記述方法が4種類くらいある

文字列で与える方式と型指定オブジェクトをインポートして渡す方式があり、それぞれに短縮名と通常名がある。知らずに参考例を探しているとそれぞれ微妙に異なって混乱する。

以下の4つは記述が微妙に異なるが効果は同じ。

from numba import jit

@jit('f8[:](f8,i8,b1,f8[:,:])', nopython=True, cache=True)    # お手軽
from numba import jit, f8, i8, b1

@jit(f8[:](f8,i8,b1,f8[:,:]), nopython=True, cache=True)
from numba import jit

@jit('float64[:](float64, int64, boolean, float64[:,:])', nopython=True, cache=True)
from numba import jit, float64, int64, boolean

@jit(float64[:](float64, int64, boolean, float64[:,:]), nopython=True, cache=True)
# 発展的な機能を使う際に有用

型指定オブジェクトを渡す方式での記述をそのまま引用符で囲めば文字列方式になるようだ。

5-2. タプルは括弧を二重にする

  • @jit('i8(Tuple(f8,i8,b1),f8[:])', nopython=True, cache=True) # これはエラー
  • @jit('i8(Tuple((f8,i8,b1)),f8[:])', nopython=True, cache=True) # これは動く

初見ではまず分からない。おそらく「関数Tupleは1つの引数をとる。型指定オブジェクトのタプルをTuple関数に渡す。そして全体を文字列方式にする」なのだろう。

5-3. 返り値型を省略かつ引数が1個のときカンマが必要

  • @jit('f8(i8)', nopython=True, cache=True) # 返り値型あり
  • @jit('(i8)', nopython=True, cache=True) # エラー
  • @jit('(i8,)', nopython=True, cache=True) # 動く
  • @jit('i8,', nopython=True, cache=True) # 実はこれでも動く

返り値型省略時は記述をPython文法で解釈したときに「引数の型指定のタプル」になっている必要があるようだ。

6. 空コンテナに注意

Numbaが型推論に失敗してエラーになることがあるらしい。

6-1. heapqを使うとき

(2023-12-20 追加) Numbaでも heapq (優先度付きキュー)を使用可能だが、空リストに heappush するとエラーになる。
しかもエラーメッセージは別の箇所で原因が起きたかのように表示される。

エラーになる例

from heapq import heappop, heappush
from numba import jit

@jit(nopython=True, cache=True)
def use_heapq():
    hq = []
    heappush(hq, 5)    # 空リストにheappushすると後でエラー
    heappush(hq, 8)
    elem = heappop()
    return elem

print(use_heapq())

動作する例

from heapq import heappop, heappush
from numba import jit

@jit(nopython=True, cache=True)
def use_heapq():
    hq = [5]    # リストに1つ目の要素を入れてからheap関数に渡すと動く
    heappush(hq, 8)
    elem = heappop(hq)
    return elem

print(use_heapq())

7. Numbaのバージョン違いによるトラブル

いくらでもあることだけど、例えばAtCoderのNumba 0.48と最近のNumbaで異なる点を挙げる。

7-1. Numpy配列のイテレーション (2次元以上?の場合)

AtCoder (Numba 0.48) では不可、最近のNumbaでは可能。

コード例 (クリックで展開/折りたたみ)
from numba import jit
import numpy as np

@jit(nopython=True, cache=True)
def funciter(arr2d):
    s = 0
    for arrflat in arr2d:    # ここ
        for v in arrflat:
            s += v
    return s

nda = np.array([2, 4, 6, 8 ]).reshape(2,2)
print(funciter(nda))
# Numba 0.53 だと動く
# Numba 0.48 (AtCoder) だとエラー
対応: インデックスアクセスで回す。素Pythonと異なりNumbaはインデックスアクセスが速い。手元での計測(Numba 0.53)だとインデックスアクセスの方がやや速かった。
対応例コード (クリックで展開/折りたたみ)
from numba import jit
import numpy as np

@jit(nopython=True, cache=True)
def funciter(arr2d):
    s = 0
    for i in range(arr2d.shape[0]):    # ここ
        for j in range(arr2d.shape[1]):
            s += arr2d[i, j]
    return s

nda = np.arange(12).reshape(4,3)
print(funciter(nda))
# Numba 0.48 (AtCoder) でも動く

7-2. jitclassの位置

クラスをNumba化する@jitclassもあるが、Numbaのバージョンによって位置が異なる。

  • 古めのバージョン (AtCoder の 0.48 など) では numba.jitclass
  • 最近のバージョンでは numba.experimental.jitclass

7-3. NumPyのdtype指定は文字列でなくNumPy型オブジェクト渡しが確実

本来のNumPyではdtype引数に例えばnp.int64'int64'intどれを渡しても動くが、Numbaではnp.int64が確実。

トラブル例: 文字列渡し (クリックで展開/折りたたみ)
import numpy as np
from numba import jit

@jit(nopython=True)
def func():
    r = np.zeros(4, dtype='int64')  # dtype指定を文字列ですると...
    return r

ans = func()    # Numba 0.48 ではエラー、最近のNumbaは動く
print(ans)
トラブル例: Python型オブジェクト渡し (クリックで展開/折りたたみ)
import numpy as np
from numba import jit

@jit(nopython=True)
def func():
    r = np.zeros(4, dtype=int)  # dtype指定をPython型オブジェクト渡しですると...
    return r

ans = func()    # Numba 0.48 でも最近のNumbaでもエラー
print(ans)
対応例: NumPy型オブジェクト渡し (クリックで展開/折りたたみ)
import numpy as np
from numba import jit

@jit(nopython=True)
def func():
    r = np.zeros(4, dtype=np.int64)  # dtype指定はnp.型オブジェクト渡しが確実
    return r

ans = func()    # Numba 0.48 でも最近のNumbaでも動く
print(ans)

8. 公式ドキュメントのサイトが2つある

ドキュメントだけ別サイトに引っ越して古い方を放置している模様。

検索では古い方が引っかかりやすいようなので最新の情報を探すときは注意。

ドキュメント以外は元のサイトのままっぽい。

終わり

他にミスりやすいポイントがあったら教えてください。

Numba はすごいのでみんなNumbaを使おう。

40
32
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
40
32

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?