3
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Pythonの max はどれだけ時間がかかるのか?

Last updated at Posted at 2017-09-28

本日は

Numba でどれだけ高速が期待できるか, 他の機能と比べてどうかを調べてた時に偶然発見した話題です.

max

Pythonにおいて配列の最大値を求める一番単純なのは組み込みの max ですね.

def py_max(array):
    max_value = max(array)
    return max_value

こんなの関数化してなんの意味がと思われるでしょう.後で時間を計測したいからです.

自前で最大値を求めると次のようなコードになります:

def scratch_max(array):
    ret = array[0]
    for value in array:
        ret = value if ret < value else ret
    return ret

loopがあるとなんとなく Numba でJITコンパイルさせたくなります:

@jit
def scratch_max_jit(array):
    ret = array[0]
    for value in array:
        ret = value if ret < value else ret
    return ret

そういえば Numpy にも max あったっけか:

def np_max(array):
    max_value = np.max(array)
    return max_value

 実行時間の計測モジュール

次のような時間計測モジュールをつくります:

measure.py
from functools import wraps
import time


def measure_time(function):
    @wraps(function)
    def measure_target(*args, **kwargs):
        start = time.time()
        ret = function(*args, **kwargs)
        end = time.time()
        return function.__name__, ret, end-start
    return measure_target

例えば次のように時間を計測したい関数に対してmeasure_timeによって修飾することで,修飾された関数は関数名,関数戻り値,実行時間の組みを返すようになります.

example.py
from measure import measure_time


@measure_time
def py_max(array):
    max_value = max(array)
    return max_value


def main():
    print(py_max([1, 4, 2, 8, 5]))

if __name__ == '__main__':
    main()

実行すると次のように結果が出力されます.

$ python example.py
('py_max', 8, 2.86102294921875e-06)

時間計測してみましょう

benchmark.py
from measure import measure_time
from numba import jit
import numpy as np
from array import array


@measure_time
def py_max(array):
    max_value = max(array)
    return max_value


@measure_time
def np_max(array):
    max_value = np.max(array)
    return max_value


@measure_time
@jit('int64(int64[:])')
def scratch_max_with_annotate(array):
    ret = array[0]
    for value in array:
        ret = value if ret < value else ret
    return ret


@measure_time
@jit
def scratch_max_jit(array):
    ret = array[0]
    for value in array:
        ret = value if ret < value else ret
    return ret


@measure_time
def scratch_max(array):
    ret = array[0]
    for value in array:
        ret = value if ret < value else ret
    return ret


@measure_time
def get_shuffled_list(arr):
    return shuffle_list(arr)


def benchmark(arr):
    print('before shuffle', type(arr))
    np.random.shuffle(arr)
    print('after shuffle', type(arr))
    print(scratch_max(arr))
    print(scratch_max_jit(arr))
    # print(scratch_max_with_annotate(arr))
    print(py_max(arr))
    print(np_max(arr))


def main():
    N = 100000000
    benchmark(list(range(N)))

if __name__ == '__main__':
    main()

実行例

実行結果はこちら(MacBook12-inch 初代)

$ python benchmark.py
before shuffle <class 'list'>
after shuffle <class 'list'>
('scratch_max', 99999999, 55.70367407798767)
('scratch_max_jit', 99999999, 50.232807874679565)
('py_max', 99999999, 31.76629090309143)
('np_max', 99999999, 75.61286282539368)

んー自前で書いた関数に@jitしてもさほど意味が...というか np.max 遅い.どういうことだ.悶々しながらいろいろトライした結果引数の型が結構大事ということでした.

引数の型を変更する.

benchmark の入力引数を変えてみましょう.

array.array

benchmark.py(array.array)

# 中略

def main():
    N = 100000000
    bench_mark(array('L', list(range(N))))

if __name__ == '__main__':
    main()

どうだ!

$ python benchmark.py
before shuffle <class 'array.array'>
after shuffle <class 'array.array'>
('scratch_max', 99999999, 27.1295268535614)
('scratch_max_jit', 99999999, 0.28397393226623535)
('py_max', 99999999, 4.903533935546875)
('np_max', 99999999, 0.1654500961303711)

かなりましになりました.
array モジュールの説明

基本的な値 (文字、整数、浮動小数点数) のアレイ (array、配列) をコンパクトに表現できるオブジェクト型を定義しています。

といっているだけありますね.

benchmark の引数をnumpyの配列にしたらどうなるの?

numpy.ndarray

やってみましょう. scratch_max_with_annotate 関数を有効にします.
@jit にヒントを追加すれば速くなるはずなので.

benchmark.py(numpy.ndarray)

# 中略

def bench_mark(arr):
    print('before shuffle', type(arr))
    np.random.shuffle(arr)
    print('after shuffle', type(arr))
    print(scratch_max(arr))
    print(scratch_max_jit(arr))
    print(scratch_max_with_annotate(arr))
    print(py_max(arr))
    print(np_max(arr))

def main():
    N = 100000000
    benchmark(np.arange(N))

if __name__ == '__main__':
    main()

どうだ!

$ python benchmark.py
before shuffle <class 'numpy.ndarray'>
after shuffle <class 'numpy.ndarray'>
('scratch_max', 99999999, 29.86492681503296)
('scratch_max_jit', 99999999, 0.170396089553833)
('scratch_max_with_annotate', 99999999, 0.11565899848937988)
('py_max', 99999999, 10.55059814453125)
('np_max', 99999999, 0.10283994674682617)

JITコンパイルで高速にできること, @jit にヒントを追加することで速くなることも確認できました.わずかながら np_max 関数が最速よね.また py_max だけで見ると array モジュールを入力とした方が良いみたい.

比較表

type(arr)\func scratch_max scratch_jit py_max np_max scratch_max_with_annotate
'list' 55.70 50.232 31.76 75.612 NAN
'array.array' 27.12 0.28 4.90 0.16 NAN
'numpy.ndarray' 29.86 0.17 10.55 0.10 0.11

テーブルの作り方は Qiitaのテーブルの書き方についてまとめた
を参考にしました.とてもよくまとまっていて助かっています.

わかったこと

  • 何も考えず最大値求める時は組み込みのmaxではなくnp.maxで良いねという考えを改めないといけない.
  • むしろ逆に遅くなる場合もあること.
  • 入力するデータ構造に対応して適切な関数を選択すべきこと.
  • 書くのに結構エネルギーがいること.チカレタ...
3
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?