9
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Numbaを使った際に困ったこと

Last updated at Posted at 2024-08-29

自分用のメモです
CPUを使う範囲では、まあまあの時間使用してきたので(時間だけかも)、メモを残しておこうと思います。
ただの物理学科の大学院生が、ほぼ独学で学んだことなので、まあ参考にできるところもあれば、ならないところもあるかなくらいで見ていただけると、と思います。
ここは違うだろってところは教えていただきたいです。

version

M1 macbookair, macOS=14.6.1(23G93), conda=23.7.4(あんま関係ないか)
Python=3.10.14, numba=0.60.0, numpy=1.26.4

0. numbaとは、

他のサイトで読むまとまっているところが多いし、深く語るだけの知識もないので、ちょろっとだけ

先にまとまっているところ

メリット

  • ほぼpython+numbaの書き方でかなりの高速化ができる
  • installはすでにpythonを使用していれば(CPU使用の範囲なら)めっちゃ簡単
  • 他サイトによるとC言語の1/2くらいの速度は出るらしい。原理的にjuliaと同じくらいか?

デメリット

  • python code → numbaはなかなかそのままでは動かない
    • numpyの関数がnumbaにないこと多数、axisは基本使えない
    • 並列化をする場合、そりゃあpythonそのままではできるわけがない
    • そもそもpythonではforを使わない方がいいとされるが、numbaでは使った方がいい
      そのため、python code → numbaでそのまま動いても最速ではないことが多い
  • numba code → pythonはほぼ確実に動くが、numbaでの速いコードはpythonのみだと遅い
    • pythonではforを使わずnumpyでなるべく処理した方が速いが、
      そもそもどうしてもforを使わなければならないが高速にしたい時にnumbaを使う感覚なので
      どうしてもnumba code → pythonは遅くなる
      加えてnumpyの関数がnumbaにないことも多いので、その分も遅くなる。
  • 頑張って早くしてもcやfortranの方が速い(と思われる)
  • そもそもnumbaで最速を目指すとCみたいな書き方が推奨される

1. 基本的に。

基本はnjit,prange,set_num_threads,objmode,get_num_threadsしか使わない。
ほぼ全部とりあえず、@njit(error_model="numpy")をつければ早くなるかエラーになるか。
@jit@njitがあるが、
numbaに対応していないものがある時には

  • @jit : とりあえず動かす。元のpythonより遅くなることも
  • @njit : エラーを吐く。

つける関数を早くするのが目的だろうから、@njitを使って、エラー吐いたら治す。
numba関数内では対応した関数と自作numba関数しか使用できない。
意外にもmatch caseとかの新しいpython関数も使用可能。
dictは使用可能(公式的にはなんか試作段階ってどこかで言っていたような気がするが、困ったことはない)。後述。

2. 並列化

cpu使用数をset_num_threads(n)で決める。njitの中でも外でもok
(たぶん)hyperthreading,SMTは使えない? 今の所使う方法が自分はわからないです。
並列化したいところでprangeを使用してfor文を作る。基本形は多分こう

from numba import njit,prange,set_num_threads
@njit(parallel=True,error_model="numpy")
def parallel_func(num_threads,random_x):
    set_num_threads(num_threads)
    loop=random_x.shape[0]
    for i in prange(loop):
        njit_func(random_x[i])

get_thread_id()でthread_idを取得できるが、それを見ると結構バラバラに走っている。(いつのバージョンからか直っていました。)
上のように実行すると、各スレッドではnp.array_split(np.arange(loop),num_threads)に沿ってiに値が入ります(分からなければ下の確認用コードを見てね)。
つまり、loop=80,num_threads=8だとi=0~9が同じコアで実行されます。
で、例えば計算がどこかのコアで早く終わった時、そのコアは別に他のコアの仕事を手伝ってくれたりしないため、無駄が生じます。
例えばiが大きくなると処理に時間がかかるような場合、thread_id=0は早く終わって休んでて、thread_id=7が最後まで計算している、のようになります。そんな時には

@njit(parallel=True,error_model="numpy")
def parallel_func(num_threads,random_x):
    set_num_threads(num_threads)
    loop=random_x.shape[0]
    for thread_id in prange(num_threads):
        for i in range(thread_id,loop,num_threads):
            njit_func(random_x[i])

のようにするといいです(ストライドっていうらしい、このページにあった)。
わざわざget_thread_id()を使わなくていいので、かなりお気に入り。

確認用コード(折りたたみ)
import numpy as np
from numba import njit,prange,set_num_threads,get_thread_id
    
@njit(parallel=True,error_model="numpy")
def parallel_func(num_threads,N):
    set_num_threads(num_threads)
    which_thread_used=np.full(N,-1,dtype="int")
    for i in prange(N):
        which_thread_used[i]=get_thread_id()
    arange=np.arange(which_thread_used.shape[0])
    for i in range(num_threads):
        print("thread_id=",i," : i=",arange[which_thread_used==i])
num_threads=8
parallel_func(num_threads,100)
->
thread_id= 0  : i= [ 0  1  2  3  4  5  6  7  8  9 10 11 12]
thread_id= 1  : i= [13 14 15 16 17 18 19 20 21 22 23 24 25]
thread_id= 2  : i= [26 27 28 29 30 31 32 33 34 35 36 37 38]
thread_id= 3  : i= [39 40 41 42 43 44 45 46 47 48 49 50 51]
thread_id= 4  : i= [52 53 54 55 56 57 58 59 60 61 62 63]
thread_id= 5  : i= [64 65 66 67 68 69 70 71 72 73 74 75]
thread_id= 6  : i= [76 77 78 79 80 81 82 83 84 85 86 87]
thread_id= 7  : i= [88 89 90 91 92 93 94 95 96 97 98 99]

3. 並列化に伴っての注意(別にnumbaに限らない)

内容が割と釈迦に説法だと思います。最初自分も躓いたため、書いておきます。
以下はrandom_xの最大値を(解説用にわざわざ)求めています。
通常のpythonの時の感覚で書くと、

random_x=np.random.random(1000)
def find_max(arr):
    max_x=0
    for i in range(arr.shape[0]):
        if max_x<arr[i]:
            max_x=arr[i]
    return max_x
print(find_max(random_x))
->0.9995210476360222

まあ通常のpythonであれば普通に動きますが、これをそのままparallel=Trueにすると

@njit(parallel=True)
def jit_find_max(num_threads,arr):
    max_x=0
    for i in prange(arr.shape[0]):
        if max_x<arr[i]:
            max_x=arr[i]
    return max_x
num_threads=8
print(jit_find_max(num_threads,random_x))
->0.0

明らかに変になります。
これは同時に同じ場所に複数のスレッドが書き込もうとするためです(たぶん)。
ちなみに、parallel=Trueをとると普通に動きます。というのもparallel=Trueがないとprangeがただのrangeになるため。
また、そもそも@njitすらつけないと、prangeはpythonのrangeになるため、そのままpythonのコードとして動きます(ありがたすぎ)。
以下のようにすると動くようになる。

@njit(parallel=True)
def parallel_find_max(num_threads,arr):
    save=np.zeros(num_threads,dtype="float64")
    for thread_id in prange(num_threads):
        for i in range(thread_id,arr.shape[0],num_threads):
            if save[thread_id]<arr[i]:
                save[thread_id]=arr[i]
    return np.max(save)
print(parallel_find_max(num_threads,random_x))
->0.9995210476360222

prangeの外で定義した変数にはprangeの中で値を変えようとしない方がいい。
計算結果を入れておく場所としては、prangeの外で1次元目の長さがnum_threadsなnumpy.arrayを用意しておくといい。
とにかくスレッドごとに値を入れる場所を分ける。

もしくは

@njit(parallel=True)
def parallel_find_max_2(num_threads,arr):
    save=np.zeros(num_threads,dtype="float64")
    for thread_id in prange(num_threads):
        max_x=0
        for i in range(thread_id,arr.shape[0],num_threads):
            if max_x<arr[i]:
                max_x=arr[i]
        save[thread_id]=max_x
    return np.max(save)
print(parallel_find_max_2(num_threads,random_x))
->0.9995210476360222

max_x=0ように、prangeの内部で定義したものはスレッド間で干渉しない。
ただし、見てわかるように結局スレッド間での比較(np.max(save))がしたいので、
外側に保存用の配列(save)を作らなければならないのは変わらない。

4. 非対応なものを使いたい時

pythonでは様々パッケージを使用できることが利点だと思うが、numbaではnumpy以外非対応。
基本的には早くしたい部分のみ、関数で囲ってnjitするが、どうしても使いたい時もある。
例えばnumbaではそのままのdictは引数として取れないが、numba funcの中で作ったものなら引数にできる。
ならデータを'np.load'とかpickleでとってきたいが、当然非対応。

from numba import objmode
with objmode(data='int8[:,:]'):
    data=np.load(f'cache_folder/cache.npy')

objmode()の括弧の中は型を指定。上記ならint8の2次元配列。
objmodeの中で新しく変数を定義して外で使う時のみ型を指定する。
当然並列化ができず(GILがある状態)、jitでもないため、最終奥義というか、なんというか

5. 型推論について

numbaにはデコレータの部分で型を指定できるが、基本的にはしなくても動くし、速度もあまり変わらないらしい(下の記事にて)
ただしnp.zeros,np.ones,np.empty,np.fullなどあー型わかんなそうだなってものにはdtype="~~"をつける。
また、returnにはすべての条件で同じ型の返り値で統一する。

6. 非対応関数、オプション

以下の通り

基本的にパラメータはサポートされていないです。ほぼaxisが使えないのが痛いですが、自力でどうにかなることがほとんどだと思います。
配列を結合したい時に、np.concatenateを使いますが、注意点あり
printはf文字列とかが使えないですが、カンマ区切りが使える(print("a=",a)みたいな)ので全然問題ないです。
エラーについて
実装されていない関数を使用すると、以下のようなエラーが出るのですぐわかる。

Use of unsupported NumPy function 'numpy.tile' or unsupported use of the function.

また、使えないパラメータを設定すると

No implementation of function Function(<function argsort at 0x1058bca60>) found for signature:
 >>> argsort(array(int64, 3d, C), axis=Literal[int](2))

7. 多次元配列のインデックス

配列x=np.arange(10000).reshape((10,10,10,10))があったとする(4次元)
そのインデックスをnumba内で指定しようとする時、癖があるので注意。
x[1] → 流石にok
x[:,1] → 流石にok
x[1,:,2] → 流石にok
x[[1,2]] → ERROR、indexの指定にlistは使えない → ndarrayにするとok
x[np.array([1,2])] → (汚いけど)ok
x[:,np.array([1,2])] → (汚いけど)ok
x[1,:,np.array([1,2])] → ERROR、間に:が挟まるとなんかerror
x[np.array([1,2]),:,np.array([1,2])] → ERROR、同上

そもそもnumpyでもx[1,:,[1,2]]は返り値のshapeとして想像する(10,2,10)ではなく(2,10,10)が帰ってくる(記事の最後に記述)。
ので、あまり使いたくない。そんなこと言っても使う時は来るので、ゴリ押しだが、
np.transpose(x[1],(1,0,2))[np.array([1,2])]がnumpyのa[1,:,[1,2]]と同じでnumbaで使用可。
ただし、shapeは(10,2,10)ではなく(2,10,10)なので(記事の最後に記述)、shape=(10,2,10)を得るには
np.transpose(np.transpose(x[1],(1,0,2))[np.array([1,2])],(1,0,2))

8. 多次元配列をboolでfilter

多次元配列のboolをfilterとして用いることができません。

@njit(parallel=True,error_model="numpy")
def func(n):
    num=np.arange(n**2).reshape((n,n))
    TF=(np.arange(n**2)%2==1).reshape((n,n))
    return num[TF]
func(4)
->エラー

9. dictについて

上でも書いたが、pythonのdictをそのまま入れるとエラーになる。
ので、dictはnumba funcの中で作る必要がある。

sample_dict={i:j for i,j in enumerate(np.random.random(10))}
@njit
def jit_dict_func(sample_dict):
    return sample_dict[1]
jit_dict_func(sample_dict)
->エラー

@njit
def make_dict():
    return {i:j for i,j in enumerate(np.random.random(10))}

maked_in_jit_dict=make_dict()
jit_dict_func(maked_in_jit_dict)
->0.03346689489240351

もしくはコメントにて教えていただいたが、numba.typed.Dict()を使用するとjitの外からnumbaのdictが作れる。

これまたコメントにて教えていただいたが、dictはかなーり遅い。ndarray比15倍ほど遅いかなと。
まとめてでっかいndarrayを放り込んで管理するとか、そういう用途に使用していきたい。

NumbaではNumPy配列が最速、list は(速いのを選べば)かなり速い、dictとsetは遅い、という傾向のようです。

確認用コード(折りたたみ)
@njit
def make_dict(base_arr):
    return {i:base_arr[i] for i in range(base_arr.shape[0])}

@njit
def func(index_arr,dict_or_arr):
    num=0
    for i in range(index_arr.shape[0]):
        num+=dict_or_arr[index_arr[i]]
    return num

n=10000000
index_arr=np.arange(n)
np.random.shuffle(index_arr)
jit_ndarray=np.random.random(n)
jit_dict=make_dict(jit_ndarray)
print(func(index_arr,jit_dict),func(index_arr,jit_ndarray))
-> 5000079.66826389 5000079.66826389

%%timeit
func(index_arr,jit_dict)
-> 601 ms ± 14.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%%timeit
func(index_arr,jit_ndarray)
-> 38.6 ms ± 1.69 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

10. classについて

これもpythonそのままのclassは使用できない。まあこれはselfの中身を型推論するのが無謀だろうからしょうがない。
numbaのclassとしてjitclassがある。

ただ、pandasとかと一緒に使用したいだろうから@staticmethodは必須感があるし、
もしnumba funcのみでも、現在parallel,error_modelが使えない(無理やりする方法はある?)ので、
まだ時期じゃないのかなと
特にerror_modelが使用できないは結構致命的だと思う。
何せゼロディビジョンエラーをinfにしてくれず、そこで計算が止まってしまう。
流石に計算用途では厳しいかと
一応、無理やりjitclassでparallel。試したことはない。

11. numba_progress

高速にするような計算はやっぱりプログレスバーが欲しい。進んでる様子が見られないと困る。
でも並列化してる時とかはatomic演算子がないと作るのむずかしいなと思っていたら作ってくれている人がいる。

Numba-progress.png

こんな感じ、すごい

以下のように使用する

from numba import njit, prange
from numba_progress import ProgressBar

num_iterations = 100

@njit(nogil=True, parallel=True)
def numba_function(num_iterations, progress_proxy):
    for i in prange(num_iterations):
        #<DO CUSTOM WORK HERE>
        progress_proxy.update(1)

with ProgressBar(total=num_iterations) as progress:
    numba_function(num_iterations, progress)

tqdmが元になっているので、tqdmのオプションがほぼ全部(?)使用できる。
個人的によく使用するのは、

  • total=int : 実行回数
  • dynamic_ncols=bool : 表示する領域に合わせてくれる()
  • disable=bool : nohupとかでバーを表示したくない時に
  • bar_format : そのまま、
  • leave=bool : 計算が終わった時にバーを残すか

11.1. loggerに進捗を出したい

自作のものです。もしかしたらnumba_progressにすでに機能があるのかもですが、

12. atomic演算

上のnumba_progressで、どうやって作ってるのかを覗いてみた時に、atomic_addなどが実装時てあるのを発見した。

from numba_progress.numba_atomic import atomic_add, atomic_xchg,atomic_min
@njit(parallel=True)
def sample_atomic(N):
    count=np.zeros(1,dtype="int64")
    for i in prange(N):
        atomic_add(count,0,1)
    print("atomic_add N count : ",count,count[0])
sample_atomic(1000)
atomic_add N count :  [1000] 0

見みて分かるように、配列のままだとちゃんとカウントできてるが外に出すと0になっている。
まあ一応使えそうな感じになっているので、使うときに試行錯誤すべし。
ただ、addとかはint限定だったりなど、色々制限があるので注意。以下を要確認。

13. numbaとnumpyで動作が違う

とっきどきあります。とは言っても今覚えてるのは以下のみ

a=np.arange(5)
a[1:]+=a[:-1]
print(f"python {a}")
@njit
def jit_a():
    a=np.arange(5)
    a[1:]+=a[:-1]
    return a
print(f"numba {jit_a()}")
python [0 1 3 5 7]
numba [ 0  1  3  6 10]

numbaは新規で配列を作らずに、計算してるからこうなるのかな?と思います。

14. 高速化について自分が知らなかったこと(numbaに限らず)

loop

pythonでは使用しない方が速いとされるforなどループですが、numbaでは使用しても問題なく高速です。
むしろnumpyのやり方では無駄な計算が必要になるケースでは積極的に使った方がいいかも。

Memory Layout

以下記事のMemory Layout,ストライドの部分はかなり速度に影響してくる場合がある。
できれば常に意識してコードを書きたい。
特に最後の「何度も計算する必要がある場合、速度の低下を防ぐためにはビューで計算するのではなく、コピーを作成して計算した方が速くなる時ときがあります。」はよく効いたことがあった。

思ったより効果がなさそう

  • a=a+1a+=1にしなければならないと思っていたが、そんなことなかった。
  • むしろ2回に分けて+=arr[i],/=2みたいにする場合、遅くなる場面も
確認用コード(折りたたみ)
@njit(parallel=True)
def func_a(num_threads,arr):
    save=np.ones((num_threads,arr.shape[1]),dtype="float64")
    for thread_id in prange(num_threads):
        for i in range(thread_id,arr.shape[0],num_threads):
            save[thread_id]=(save[thread_id]+arr[i])/2
            #save[thread_id]=save[thread_id]+arr[i]#ここコメント外して1回に
    return save[0,0]
    
@njit(parallel=True)
def func_b(num_threads,arr):
    save=np.ones((num_threads,arr.shape[1]),dtype="float64")
    for thread_id in prange(num_threads):
        for i in range(thread_id,arr.shape[0],num_threads):
            save[thread_id]+=arr[i]
            save[thread_id]/=2#ここコメントアウトで1回に
    return save[0,0]
random_x=np.random.random((10000000,5))
num_threads=8
print(func_a(num_threads,random_x))#一旦コンパイル
print(func_b(num_threads,random_x))#一旦コンパイル

%%timeit
func_a(num_threads,random_x)
->270 ms ± 18.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%%timeit
func_b(num_threads,random_x)
->532 ms ± 50.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
  • 追記予定

14.1 . 高速化について多分そうっぽいこと

以下は特に検証用コードを用意して確かめたものではないので、話半分で、
正直わかんないで使ってるところがあるので、教えて欲しいという立場です

nogil=Trueについて

多分速度アップの効果はなさそう。色々試したけどアップしたことはなし。
numba prangeで回す場合、ないほうがほんのちょっとだけ早かった。
pythonのthreadingとかmultiprocessingで回す人向けかなと。

fastmath=Trueについて

的当になんでもつけるとなぜか遅くなる。
明らかに速くなりそうな、floatの計算が絡む場所のみにつけた方がいいみたい。
また、SVMLに対応しないとなんの意味もないと明言されている。

numba -sicc_rtが入っていなければconda install -c numba icc_rtでインストール。
ちなみにM1 macはインストールできない、intelの機能だしね。
SVMLのコントロールの仕方は下でかく。

各種コンパイルオプション

ざっと見た感じ、速度に関わりそうなオプションは以下の通り

  • OPT

    • 0,1,2,3,'max'
    • どんくらい最適化するか見たいな
    • 標準で3、3から'max'への変化は正直わからなかった
  • ENABLE_AVX

    • 0,1で指定
    • AVXっていうSIMDの新しめのを使用するか
    • numba -sでCPU Featuresににいろいろ書いてあるのが使えるらしい
    • M1 macはなんも書いてない
  • DISABLE_INTEL_SVML

    • 0,1で指定
    • そのままの意味
    • icc_rtを入れていないと意味なし、多分標準で0になっている
  • DISABLE_JIT

    • jitするかをここで決められる
    • @njitをつけていても、問答無用でpythonにしてくれる
    • .pyで別ファイルに分けて、そこにjit関数があっても問答無用でpythonにしてくれる
  • THREADING_LAYER

    • 'default','safe','forksafe','threadsafe','tbb','omp','workqueue'
    • safeはtbbのインストールが必要らしい
    • tbb,ompはそれぞれインストールが必要、conda install openmp tbb
    • M1macではopenmpはダメでした
    • 何が変わってるのかよくわかってないので、はい
    • 速度感は、自前のコードでは以下のようになった、多分コードによるんだろうと思う
THREADING_LAYER
default 34.813884 s
safe 36.765354 s
forksafe 36.809115 s
threadsafe 37.991987 s
tbb 32.778061 s
omp 29.010565 s
workqueue 47.317840 s

各種コンパイルオプションコントロール

変え方はいくつかあって、

環境変数

OPTだったら前にNUMBA_をつけて、NUMBA_OPT = 'max'という環境変数を定義する。
THREADING_LAYERならNUMBA_THREADING_LAYER = 'omp'のように

.numba_config.yaml

.numba_config.yamlという名のファイルを作業ディレクトリに作って、そこに以下のように書く
小文字にして、NUMBA_を取っ払う
jupyterlabでは最初隠しファイルが見えないので設定が必要。下にまとめる。

disable_jit: 0 # 0 -> jit enable
opt: "max" #0~3, "max"
threading_layer: "default" # "safe","forksafe","threadsafe","tbb","omp","workqueue"
enable_avx: 0
disable_intel_svml: 1 # 0 -> enable
numba.config.

関数の内部からもnumba.config.OPTの形でアクセスできる。
なんかこっちは大文字で、NUMBA_はなし。
アクセスなので、値を読むこともできるし、入れることもできる。
ただし、OPTのみ注意、下の例のようにnumba.config._OptLevel()に入れる必要あり

import numba
numba.config.OPT=numba.config._OptLevel("max")
numba.config.THREADING_LAYER="default"
numba.config.DISABLE_JIT=0

まあ、.numba_config.yamlnumba.config.が使いやすそうかなと
そもそもそこまでする必要があるのかと。
他にもoverride_configみたいなのもあるらしい。以下ページ参照

他ページ

jupyterlabで隠しファイルを見る設定

だけだとダメだったので、
「JupyterLab の メニュー → [表示] → [Show hidden files] のチェック」

15. numbaのバグ(?)

おそらくnumbaのバグで、自身が遭遇した者を羅列する場所。
当然バージョンによっては存在しないものもあるだろうから、無くなってたら開発者に感謝

  1. np.argpartition/np.partitionで
  • numba=0.60.0
  • エラー:kth out of bounds
  • なんだか配列の長さと同じkを設定するとエラーになるらしい(まあいいや)

感想というか、

驚異的に速くなるので本当にありがたいです。
早くなった上で、それでも許容できない時間がかかることが多いので、さらにアルゴリズムの見直しとか、無駄な処理していないかのチェックとかは必要になりますが、そもそもnumbaがなければそれ以前なので、本当にありがたい

気に食わないこと(numbaではなくnumpy)

numpyの使用で、以下のようなものがあります。

a=np.random.random((10,11,12,13))
a[:,:,[1,2,3]].shape
->(10, 11, 3, 13)  # ok
a[1,:,[1,2,3]].shape
->(3, 11, 13) ## は?? ((11,3,13)になってほしかった)

何が気に食わないかというと、a[1,:,[1,2,3]].shapeについてです。
np.array([1,2,3])が指定しているのは元の長さが12の部分です。

ので、最初のshape=(10,11,12,13)の1次元目の10がなくなり、12の部分が3で(11,3,13)が欲しい。
でも帰ってきたのは(3, 11, 13)です。なんか3が先頭に行ってしまっているんですね。
どうやら指定するindexの間に:を挟むとこういったことになるようです。

a[:,[1,2],:].shape
->(10, 2, 12, 13)  # ok
a[:,[1,2],:,1].shape
->(2, 10, 12)  # 変になってる

どなたかなぜこうなっているのか教えて欲しいです。

9
5
5

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
9
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?