About this article
These are notes from trying Cython, Numba, and numexpr to speed up Python code.
There are no detailed explanations.
Environment
- Ubuntu 16.04
- Python 3.5.4
- Anaconda 1.6.6
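Aiming for speed
The computation to speed up
The following computation runs a large number of times, so I want to make it faster.
for ~
    for ~
        y = np.log(1.0 + np.exp(x))
Plain NumPy
First, the original code.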
import numpy as np
import time

if __name__ == "__main__":
    x = np.ones((300, 300), dtype=np.float32)

    # original
    start = time.time()
    for n in range(1000):
        for i in range(1000):
            y = np.log(1.0 + np.exp(x))
    print("original took %f s" % (time.time() - start))
Result
original took 30.942069 s
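As a rough cross-check on the total above, the cost of a single evaluation can be measured on its own. The sketch below uses timeit from the standard library. The helper name time_one_call and the repeat counts are assumptions for illustration, not part of the original script.

```python
import timeit

import numpy as np


def time_one_call(x, number=20, repeat=3):
    """Return the best per-call time in seconds for log(1 + exp(x))."""
    timer = timeit.Timer(lambda: np.log(1.0 + np.exp(x)))
    # Take the minimum over several repeats to reduce noise from other processes.
    best_total = min(timer.repeat(repeat=repeat, number=number))
    return best_total / number


if __name__ == "__main__":
    x = np.ones((300, 300), dtype=np.float32)
    per_call = time_one_call(x)
    print("one evaluation took %.3f ms" % (per_call * 1e3))
```

Multiplying the per-call time by the number of loop iterations gives an estimate to compare against the measured total.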
Cython
Next, moving the whole computation into Cython.
The .pyx file and setup.py are shown below.
I believe the variable and type declarations cover most of what is needed for a speedup.
# sample.pyx
import numpy as np
cimport numpy as np
cimport cython

ctypedef np.float32_t FLOAT_t

def test1(np.ndarray[FLOAT_t, ndim=2] x):
    cdef int n, i
    cdef np.ndarray[FLOAT_t, ndim=2] y
    # 100 x 100 iterations (the NumPy loop in the script is written as 1000 x 1000)
    for n in range(100):
        for i in range(100):
            y = np.log(1.0 + np.exp(x))
# setup.py
from distutils.core import setup
from distutils.extension import Extension
from Cython.Distutils import build_ext
import numpy as np

setup(
    name="sample",
    cmdclass={"build_ext": build_ext},
    ext_modules=[Extension("sample", sources=["sample.pyx"])],
    include_dirs=[np.get_include()],
)

### or
# from distutils.core import setup
# from Cython.Distutils import build_ext
# from Cython.Build import cythonize
#
# setup(
#     cmdclass={"build_ext": build_ext},
#     ext_modules=cythonize("sample.pyx"),
# )
Run setup.py.
$ python setup.py build_ext --inplace
import numpy as np
import time
from sample import test1

if __name__ == "__main__":
    x = np.ones((300, 300), dtype=np.float32)

    # original
    start = time.time()
    for n in range(1000):
        for i in range(1000):
            y = np.log(1.0 + np.exp(x))
    print("original took %f s" % (time.time() - start))

    # Cython
    start = time.time()
    test1(x)
    print("Cython took %f s" % (time.time() - start))
Result
original took 30.662536 s
Cython took 30.764751 s
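After a few experiments, Cython did make the for loops and NumPy array indexing much faster.
However, np.log and np.exp were so slow that this gain was lost in the noise. Inside the .pyx file, np.log and np.exp are still ordinary calls into NumPy, so compiling the loop around them does not change their cost.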
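The observation above can be checked from plain Python by timing an empty loop against the same loop with the NumPy call inside. This is a minimal sketch. The function names and the iteration count of 200 are assumptions chosen to keep the run short.

```python
import time

import numpy as np


def time_empty_loop(iterations):
    """Time a Python for loop that does no work."""
    start = time.perf_counter()
    for _ in range(iterations):
        pass
    return time.perf_counter() - start


def time_numpy_loop(x, iterations):
    """Time the same loop with the NumPy expression inside."""
    start = time.perf_counter()
    for _ in range(iterations):
        y = np.log(1.0 + np.exp(x))
    return time.perf_counter() - start


if __name__ == "__main__":
    x = np.ones((300, 300), dtype=np.float32)
    iterations = 200
    loop_only = time_empty_loop(iterations)
    with_numpy = time_numpy_loop(x, iterations)
    print("loop only : %f s" % loop_only)
    print("with NumPy: %f s" % with_numpy)
```

If the second number is far larger than the first, the loop itself is not the bottleneck, so compiling only the loop cannot help much.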
Numba
Next, Numba.
It looked like adding the @jit decorator above the function to be sped up is all that is needed.
https://myenigma.hatenablog.com/entry/2017/03/02/155433
import numpy as np
import time
from numba import jit
from sample import test1

@jit
def test2(x):
    # 100 x 100 iterations (the NumPy loop below is written as 1000 x 1000)
    for n in range(100):
        for i in range(100):
            y = np.log(1.0 + np.exp(x))

if __name__ == "__main__":
    x = np.ones((300, 300), dtype=np.float32)

    # original
    start = time.time()
    for n in range(1000):
        for i in range(1000):
            y = np.log(1.0 + np.exp(x))
    print("original took %f s" % (time.time() - start))

    # Cython
    start = time.time()
    test1(x)
    print("Cython took %f s" % (time.time() - start))

    # Numba
    start = time.time()
    test2(x)
    print("Numba took %f s" % (time.time() - start))
Result
original took 30.954107 s
Cython took 30.807106 s
Numba took 153.410239 s
For reasons I don't understand, it became dramatically slower.
There may be a mistake somewhere in the code...
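One thing that can be separated out when timing a jitted function is compilation: Numba compiles a function on its first call, so that call includes compile time. The sketch below warms the function up before timing and writes the computation as explicit element loops. It is an illustration under assumptions, not a diagnosis of the slowdown above. softplus_loops is a hypothetical name, the iteration count is arbitrary, and the fallback decorator only exists so the snippet also runs where Numba is not installed.

```python
import time

import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption for portability: without Numba, run the same code uncompiled.
    def njit(func):
        return func


@njit
def softplus_loops(x):
    """Compute log(1 + exp(x)) element by element for a 2-D array."""
    out = np.empty_like(x)
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            out[i, j] = np.log(1.0 + np.exp(x[i, j]))
    return out


if __name__ == "__main__":
    x = np.ones((300, 300), dtype=np.float32)

    # The first call triggers compilation, so keep it out of the timing.
    softplus_loops(x)

    start = time.perf_counter()
    for n in range(10):
        y = softplus_loops(x)
    print("Numba (warmed up) took %f s" % (time.perf_counter() - start))
```

Timing the first call separately from later calls shows how much of a measurement is compilation rather than computation.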
numexpr
Finally, numexpr.
The syntax looks odd, but it seems the expression is passed as a string, as follows.
https://github.com/pydata/numexpr
http://memoryfolder.hatenablog.com/entry/2016/11/23/181250
import numpy as np
import time
from numba import jit
import numexpr as ne
from sample import test1

@jit
def test2(x):
    # 100 x 100 iterations (the NumPy and numexpr loops below are written as 1000 x 1000)
    for n in range(100):
        for i in range(100):
            y = np.log(1.0 + np.exp(x))

if __name__ == "__main__":
    x = np.ones((300, 300), dtype=np.float32)

    # original
    start = time.time()
    for n in range(1000):
        for i in range(1000):
            y = np.log(1.0 + np.exp(x))
    print("original took %f s" % (time.time() - start))

    # Cython
    start = time.time()
    test1(x)
    print("Cython took %f s" % (time.time() - start))

    # Numba
    start = time.time()
    test2(x)
    print("Numba took %f s" % (time.time() - start))

    # numexpr
    start = time.time()
    for n in range(1000):
        for i in range(1000):
            y = ne.evaluate("log(1.0 + exp(x))").astype("float32")
    print("numexpr took %f s" % (time.time() - start))
Result
original took 30.707380 s
Cython took 30.750478 s
Numba took 153.180645 s
numexpr took 3.371960 s
For reasons I don't understand, it became much faster.
For speeding up just one part of the code, numexpr seems to be enough.
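On the string syntax: ne.evaluate parses the expression itself and looks up names such as x in the calling scope unless a local_dict is given. The sketch below passes the array explicitly and compares the result with NumPy. The function name softplus_numexpr is an assumption, and the NumPy fallback is only there so the snippet runs where numexpr is not installed.

```python
import numpy as np

try:
    import numexpr as ne
except ImportError:
    ne = None  # Assumption for portability: fall back to NumPy below.


def softplus_numexpr(x):
    """Evaluate log(1 + exp(x)) with numexpr, returning float32 like the script."""
    if ne is None:
        return np.log(1.0 + np.exp(x)).astype("float32")
    # local_dict makes the source of the name "x" explicit instead of
    # relying on numexpr looking it up in the caller's scope.
    result = ne.evaluate("log(1.0 + exp(x))", local_dict={"x": x})
    return result.astype("float32")


if __name__ == "__main__":
    x = np.ones((300, 300), dtype=np.float32)
    y_ne = softplus_numexpr(x)
    y_np = np.log(1.0 + np.exp(x))
    print("max abs difference: %g" % np.abs(y_ne - y_np).max())
```

Wrapping the call in a function like this keeps the string expression in one place, so only the hot spot changes and the rest of the script stays plain NumPy.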