本日は
Numba でどれだけ高速が期待できるか, 他の機能と比べてどうかを調べてた時に偶然発見した話題です.
max
Pythonにおいて配列の最大値を求める一番単純なのは組み込みの max
ですね.
def py_max(array):
max_value = max(array)
return max_value
こんなの関数化してなんの意味がと思われるでしょう.後で時間を計測したいからです.
自前で最大値を求めると次のようなコードになります:
def scratch_max(array):
ret = array[0]
for value in array:
ret = value if ret < value else ret
return ret
loopがあるとなんとなく Numba でJITコンパイルさせたくなります:
@jit
def scratch_max_jit(array):
ret = array[0]
for value in array:
ret = value if ret < value else ret
return ret
そういえば Numpy にも max
あったっけか:
def np_max(array):
max_value = np.max(array)
return max_value
実行時間の計測モジュール
次のような時間計測モジュールをつくります:
from functools import wraps
import time
def measure_time(function):
@wraps(function)
def measure_target(*args, **kwargs):
start = time.time()
ret = function(*args, **kwargs)
end = time.time()
return function.__name__, ret, end-start
return measure_target
例えば次のように時間を計測したい関数に対してmeasure_time
によって修飾することで,修飾された関数は関数名,関数戻り値,実行時間の組みを返すようになります.
from measure import measure_time
@measure_time
def py_max(array):
max_value = max(array)
return max_value
def main():
print(py_max([1, 4, 2, 8, 5]))
if __name__ == '__main__':
main()
実行すると次のように結果が出力されます.
$ python example.py
('py_max', 8, 2.86102294921875e-06)
時間計測してみましょう
from measure import measure_time
from numba import jit
import numpy as np
from array import array
@measure_time
def py_max(array):
max_value = max(array)
return max_value
@measure_time
def np_max(array):
max_value = np.max(array)
return max_value
@measure_time
@jit('int64(int64[:])')
def scratch_max_with_annotate(array):
ret = array[0]
for value in array:
ret = value if ret < value else ret
return ret
@measure_time
@jit
def scratch_max_jit(array):
ret = array[0]
for value in array:
ret = value if ret < value else ret
return ret
@measure_time
def scratch_max(array):
ret = array[0]
for value in array:
ret = value if ret < value else ret
return ret
@measure_time
def get_shuffled_list(arr):
return shuffle_list(arr)
def benchmark(arr):
print('before shuffle', type(arr))
np.random.shuffle(arr)
print('after shuffle', type(arr))
print(scratch_max(arr))
print(scratch_max_jit(arr))
# print(scratch_max_with_annotate(arr))
print(py_max(arr))
print(np_max(arr))
def main():
N = 100000000
benchmark(list(range(N)))
if __name__ == '__main__':
main()
実行例
実行結果はこちら(MacBook12-inch 初代)
$ python benchmark.py
before shuffle <class 'list'>
after shuffle <class 'list'>
('scratch_max', 99999999, 55.70367407798767)
('scratch_max_jit', 99999999, 50.232807874679565)
('py_max', 99999999, 31.76629090309143)
('np_max', 99999999, 75.61286282539368)
んー自前で書いた関数に@jit
してもさほど意味が...というか np.max
遅い.どういうことだ.悶々しながらいろいろトライした結果引数の型が結構大事ということでした.
引数の型を変更する.
benchmark
の入力引数を変えてみましょう.
array.array
# 中略
def main():
N = 100000000
bench_mark(array('L', list(range(N))))
if __name__ == '__main__':
main()
どうだ!
$ python benchmark.py
before shuffle <class 'array.array'>
after shuffle <class 'array.array'>
('scratch_max', 99999999, 27.1295268535614)
('scratch_max_jit', 99999999, 0.28397393226623535)
('py_max', 99999999, 4.903533935546875)
('np_max', 99999999, 0.1654500961303711)
かなりましになりました.
array モジュールの説明で
基本的な値 (文字、整数、浮動小数点数) のアレイ (array、配列) をコンパクトに表現できるオブジェクト型を定義しています。
といっているだけありますね.
benchmark
の引数をnumpyの配列にしたらどうなるの?
numpy.ndarray
やってみましょう. scratch_max_with_annotate
関数を有効にします.
@jit
にヒントを追加すれば速くなるはずなので.
# 中略
def bench_mark(arr):
print('before shuffle', type(arr))
np.random.shuffle(arr)
print('after shuffle', type(arr))
print(scratch_max(arr))
print(scratch_max_jit(arr))
print(scratch_max_with_annotate(arr))
print(py_max(arr))
print(np_max(arr))
def main():
N = 100000000
benchmark(np.arange(N))
if __name__ == '__main__':
main()
どうだ!
$ python benchmark.py
before shuffle <class 'numpy.ndarray'>
after shuffle <class 'numpy.ndarray'>
('scratch_max', 99999999, 29.86492681503296)
('scratch_max_jit', 99999999, 0.170396089553833)
('scratch_max_with_annotate', 99999999, 0.11565899848937988)
('py_max', 99999999, 10.55059814453125)
('np_max', 99999999, 0.10283994674682617)
JITコンパイルで高速にできること, @jit
にヒントを追加することで速くなることも確認できました.わずかながら np_max
関数が最速よね.また py_max
だけで見ると array
モジュールを入力とした方が良いみたい.
比較表
type(arr)\func | scratch_max | scratch_jit | py_max | np_max | scratch_max_with_annotate |
---|---|---|---|---|---|
'list' | 55.70 | 50.232 | 31.76 | 75.612 | NAN |
'array.array' | 27.12 | 0.28 | 4.90 | 0.16 | NAN |
'numpy.ndarray' | 29.86 | 0.17 | 10.55 | 0.10 | 0.11 |
テーブルの作り方は Qiitaのテーブルの書き方についてまとめた
を参考にしました.とてもよくまとまっていて助かっています.
わかったこと
- 何も考えず最大値求める時は組み込みの
max
ではなくnp.max
で良いねという考えを改めないといけない. - むしろ逆に遅くなる場合もあること.
- 入力するデータ構造に対応して適切な関数を選択すべきこと.
- 書くのに結構エネルギーがいること.チカレタ...