LoginSignup
10
10

More than 3 years have passed since last update.

PythonとNumbaで数値計算を高速化するときの知見

Last updated at Posted at 2020-08-30

研究でPythonプログラムの高速化のためにNumbaを使用しました.
実装にあたって,いろいろエラーが出てつまずいたので,その知見をサンプルとして共有します.

私の場合ですが,ルンゲクッタ法を使う粒子群最適化法の計算が,約2000秒→約60秒と33倍の高速化になりました.

※注意※
- Numbaの基本的な使い方は参考文献を参照してください.
- すべての使用例では,nopythonモードであるnjitを使用しています.
- Numbaは引数や戻り値の型を指定しなくても動作する場合がありますが,ここではすべて指定することを前提とします.

1. 環境

$ sw_vers
ProductName:    Mac OS X
ProductVersion: 10.15.2
BuildVersion:   19C57

$ python -V
Python 3.8.5

$ pip freeze
numba==0.51.0
numpy==1.19.1

Numbaは以下のコマンドでインストール出来ます.

$ pip install numba

2. 使用例

2-1. 関数内でnp.emptyを使用する

numbaを使用する場合,関数内でnp.emptyを使用するとエラーが出ることがあります.
その場合,以下のように型を指定するとうまく動作しました.

main.py
import numpy as np
from numba import njit

@njit("f8[:,:]()")
def func():
    x = np.empty((1, 2), dtype=np.float64)
    return x

print(func())

2-2. 複数の戻り値を返す

複数の戻り値を返す場合,Tuple((i8, i8))のように書きます.
カッコが2重になっていることに注意が必要です.

main.py
import numpy as np
from numba import njit

@njit("Tuple((i8, i8))(i8, i8)")
def func(x, y):
    return x, y

print(func(1, 2))

2-3. 多次元のリストを扱う

Numbaで多次元のリストを扱う場合,f8[:,:]のように書きます.
2次元だからコロンが2つというわけではなくて,何次元でも2つでいいようです.

main.py
import numpy as np
from numba import njit

@njit("f8[:,:](f8[:,:])")
def func(x):
    return x ** 2

x = np.random.rand(5, 5)
print(func(x))

3. 終わり

Pythonの高速化はCythonやJuliaなど色々な方法がありますが,デコレーターを書くだけのNumbaによる方法は一番簡単なものだと思います.

クラスやジェネレーターが使えないなどの制約はありますが,ボトルネックを局所的に高速化するやり方であれば,比較的容易に実装できるものだと感じました.

4. 参考文献

10
10
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
10
10