Getting the most out of numpy's speed

More than 1 year has passed since last update.

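Preface

I noticed this while running a computation (reinforcement learning for Othello) in which a numpy operation (argmax) was the bottleneck, so I am sharing it here. The speed difference is larger than you might expect.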

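Conclusion

Whenever an operation is available as a method of numpy.ndarray, use the method rather than the module-level function. Also, for operations on a single number, the math module is faster.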
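
As a minimal sketch of what this advice means in practice: the function form and the method form return the same values, so switching between them is a drop-in change. The array `a`, the seed, and the variable names below are illustrative assumptions and are not part of the original benchmark.

```python
import math

import numpy as np

# Illustrative data; the fixed seed is an assumption, used only for reproducibility.
rng = np.random.default_rng(0)
a = rng.random(100)

# The module-level function and the ndarray method give the same index.
idx_func = np.argmax(a)
idx_method = a.argmax()
same_index = idx_func == idx_method

total = a.sum()

# For a single number, math.sqrt and np.sqrt agree numerically.
# math.sqrt returns a plain Python float; np.sqrt returns a numpy scalar.
root_math = math.sqrt(total)
root_np = np.sqrt(total)

print(same_index, type(root_math).__name__, type(root_np).__name__)
```

The only visible difference is the return type of the square root, which rarely matters when the result feeds into further scalar arithmetic.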

Benchmark

I ran the following benchmark in a jupyter notebook. The comment lines show the results. The numpy.ndarray methods are clearly faster. In addition, sqrt is faster with the math module.

test.py
import numpy as np
import math

a = np.random.rand(100)

%timeit np.argmax(a)
# 1.48 µs ± 19.4 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
%timeit a.argmax()
# 668 ns ± 6.52 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

%timeit np.max(a) 
# 3.32 µs ± 24.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit a.max() 
# 2.82 µs ± 30.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

%timeit np.sum(a) 
# 3.17 µs ± 34.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit a.sum() 
# 2.44 µs ± 61.9 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

%timeit math.sqrt(a.sum())
# 2.71 µs ± 39.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit np.sqrt(a.sum())
# 4.21 µs ± 36.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
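
The listing above relies on the `%timeit` magic, which only works inside IPython or jupyter. As a rough sketch for running the same comparison as a plain script, the standard-library `timeit` module can be used instead. The helper name `per_call_seconds` and the `number`/`repeat` values are assumptions chosen for illustration. Note that this reports the best of several runs rather than the mean and standard deviation that `%timeit` prints, and the `lambda` wrappers add a small constant cost to every measurement, so absolute numbers will not match the ones above.

```python
import math
import timeit

import numpy as np

a = np.random.rand(100)


def per_call_seconds(func, number=20_000, repeat=3):
    """Return the best observed time per call, in seconds, for a zero-argument callable."""
    timer = timeit.Timer(func)
    return min(timer.repeat(repeat=repeat, number=number)) / number


# Same pairs as in the notebook benchmark: function form first, then method form.
results = {
    "np.argmax(a)": per_call_seconds(lambda: np.argmax(a)),
    "a.argmax()": per_call_seconds(lambda: a.argmax()),
    "np.max(a)": per_call_seconds(lambda: np.max(a)),
    "a.max()": per_call_seconds(lambda: a.max()),
    "np.sum(a)": per_call_seconds(lambda: np.sum(a)),
    "a.sum()": per_call_seconds(lambda: a.sum()),
    "math.sqrt(a.sum())": per_call_seconds(lambda: math.sqrt(a.sum())),
    "np.sqrt(a.sum())": per_call_seconds(lambda: np.sqrt(a.sum())),
}

for label, seconds in results.items():
    print(f"{label:<20} {seconds * 1e6:8.3f} µs per call")
```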

I also ran the benchmark with a larger array. The differences are smaller, but the numpy.ndarray methods are still faster. For max, the two are almost identical.

test.py
a = np.random.rand(10000)

%timeit np.argmax(a)
# 9.94 µs ± 536 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit a.argmax()
# 8.96 µs ± 355 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

%timeit np.max(a) 
# 7.98 µs ± 151 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit a.max() 
# 8 µs ± 372 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

%timeit np.sum(a) 
# 9.12 µs ± 193 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit a.sum() 
# 7.89 µs ± 274 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

%timeit math.sqrt(a.sum())
# 7.99 µs ± 158 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit np.sqrt(a.sum())
# 9.92 µs ± 64.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
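
One plausible reading of the shrinking gap is that the function form pays a roughly fixed per-call overhead before reaching the same underlying routine, so the overhead matters less as the array grows. The article does not state this, so treat it as an assumption. The sketch below is one way to check it: it measures the ratio of function time to method time for argmax at several sizes. The sizes, the helper name `best_time`, and the repeat counts are illustrative choices, and the measured ratios will vary by machine and numpy version.

```python
import timeit

import numpy as np


def best_time(func, number=5_000, repeat=3):
    """Best observed time per call, in seconds, for a zero-argument callable."""
    return min(timeit.repeat(func, number=number, repeat=repeat)) / number


sizes = [100, 1_000, 10_000]  # illustrative sizes; the article used 100 and 10000
ratios = {}

for n in sizes:
    a = np.random.rand(n)
    t_func = best_time(lambda: np.argmax(a))
    t_method = best_time(lambda: a.argmax())
    # A ratio above 1 means the function form was slower than the method form.
    ratios[n] = t_func / t_method
    print(f"n={n:>6}: np.argmax / a.argmax = {ratios[n]:.2f}")
```

If the fixed-overhead reading holds, the ratio should drift toward 1 as `n` increases.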
ryuoujisinta
I have recently become interested in machine learning and study it as a hobby. I usually work on numerical computation.