前置き
numpyの処理(argmax)がボトルネックになるような計算(otheroの強化学習)をしていたときに気づいたので共有する。意外と速度が違う。
結論
numpy.ndarrayのメンバ関数が使える処理はそれを使おう。また、1つの数字に対する操作はmathジュールの方が速い。
ベンチマーク
jupyter notebook上で次のベンチマークを行った。コメント行は結果を示す。明らかにnumpy.ndarrayの関数が速い。また、sqrtはmathモジュールの方が速い。
test.py
import numpy as np
import math
a = np.random.rand(100)
%timeit np.argmax(a)
# 1.48 µs ± 19.4 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
%timeit a.argmax()
# 668 ns ± 6.52 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
%timeit np.max(a)
# 3.32 µs ± 24.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit a.max()
# 2.82 µs ± 30.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit np.sum(a)
# 3.17 µs ± 34.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit a.sum()
# 2.44 µs ± 61.9 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit math.sqrt(a.sum())
# 2.71 µs ± 39.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit np.sqrt(a.sum())
# 4.21 µs ± 36.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
配列サイズを大きくしたものも行った。差異は小さくなっているが、依然としてnumpy.ndarrayの関数が速い。maxは殆ど一緒になっている。
test.py
a = np.random.rand(10000)
%timeit np.argmax(a)
# 9.94 µs ± 536 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit a.argmax()
# 8.96 µs ± 355 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit np.max(a)
# 7.98 µs ± 151 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit a.max()
# 8 µs ± 372 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit np.sum(a)
# 9.12 µs ± 193 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit a.sum()
# 7.89 µs ± 274 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each
%timeit math.sqrt(a.sum())
# 7.99 µs ± 158 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit np.sqrt(a.sum())
# 9.92 µs ± 64.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)