13
15

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Numbaを使ってNumpy処理を高速化

Posted at

以下のような処理を行いたい

def transfer_balls(box, mm):
    # ボールをmm回シャッフルする
    for i in range(mm):
        # 箱を2つランダムに選ぶ(同じ箱を選ばないようにする)
        flg = True
        while flg:
            x1, y1 = np.random.randint(0, 100, (2,))
            x2, y2 = np.random.randint(0, 100, (2,))
            if x1 == x2 and y1 == y2:
                continue
            break

        nt = box[y2, x2]
        n1 = box[y1, x1]
        n2 = box[y2, x2]

        # ボールを移動させる
        if n1 - nt >= 1:
            n1 -= nt
            n2 += nt
        else:
            n2 = n2 + n1 - 1
            n1 = 1

        box[y1, x1] = n1
        box[y2, x2] = n2

    return box

箱をランダムに2つ選んでボールを移動させる、というシミュレーションを行いたい。
試行回数(mm)が100回程度なら良いのですが、これを100万回や1000万回にすると、めちゃくちゃ時間がかかってしまいます。

Numbaを使おう

どうやらNumbaを利用すると簡単に処理を高速化できるらしい。使い方↓

from numba import jit


@jit('(戻り値の型)((引数1の型),(引数2の型),…)', nopython=True)
def hogehoge(bar, foo):
	# 時間のかかりそうな処理
	return fuga

戻り値と引数の型を指定してあげて、デコレータを書くだけです。簡単ですね。
型指定方法はこちらに記載してあります。
ちなみに2次元配列の場合は 型[:,:] 3次元配列の場合は 型[:,:,:] のように書いてあげると動作します。

今回は
戻り値:numpyの2次元配列(中身はfloat64)
引数1(box):numpyの2次元配列(中身はfloat64)
引数2(mm):int32
なので、以下のように記述します。

@jit('f8[:,:](f8[:,:],i4)', nopython=True)
def transfer_balls(box, mm):
    ###
    return box

処理速度を計測

今回、mm = 400万 として計測してみました。
何もなし:147秒
jitあり:6秒
実に20倍近く早くなりました。正確に計測していませんが、mmを2500万にした場合、何もなしだと10~20分程度かかりましたが、jitありだと1分足らずで終わってしまいました。

簡単に処理速度を上げることができるので、Numbaオススメです。

13
15
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
13
15

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?