12
14

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Python高速化のためにいろいろ試した(Cython, Numba, numexpr)

Last updated at Posted at 2018-09-19

この記事は

Pythonの高速化のためにCython, Numba, numexprを試したのでメモ.
詳しい解説などは無いです.

実行環境

  • Ubuntu 16.04
  • Python 3.5.4
  • Anaconda 1.6.6

高速化を目指して

高速化したい処理

以下の処理を数多く行うので,これを高速化したい.

for ~
   for ~
      y = np.log(1.0 + np.exp(x))

numpyでそのまま

まずは元コードの場合.

looptest.py
import numpy as np
import time

if __name__ == "__main__":

   x = np.ones((300, 300), dtype=np.float32)

   # original
   start = time.time()
   for n in range(1000):
       for i in range(1000):
           y = np.log(1.0 + np.exp(x))
   print("original took %f s" %(time.time() - start))

実行結果

original took 30.942069 s

Cython

処理を丸ごとCythonにした場合.
pyxファイルとsetup.pyは以下のようにしました.
変数,型定義などできることは高速化に必要なことはだいたいできているはず.

sample.pyx
import numpy as np
cimport numpy as np
cimport cython

ctypedef np.float32_t FLOAT_t

def test1(np.ndarray[FLOAT_t, ndim=2] x):
   cdef int n, i
   cdef np.ndarray[FLOAT_t, ndim=2] y

   for n in range(100):
      for i in range(100):
         y = np.log(1.0 + np.exp(x))
setup.py
from distutils.core import setup
from distutils.extension import Extension
from Cython.Distutils import build_ext
import numpy as np

setup(
     name = "sample",
     cmdclass = {"build_ext" : build_ext},
     ext_modules = [Extension("sample", sources=["sample.pyx"])],
     include_dirs = [np.get_include()]
     )

### or
#from distutils.core import setup
#from Cython.Distutils import build_ext
#from Cython.Build import cythonize
#
#setup(
#      cmdclass = {"build_ext" : build_ext},
#      ext_modules = cythonize("sample.pyx")
#      )

setup.pyを実行.

$ python setup.py build_ext --inplace
looptest.py
import numpy as np
import time

from sample import test1

if __name__ == "__main__":

   x = np.ones((300, 300), dtype=np.float32)

   # original
   start = time.time()
   for n in range(1000):
       for i in range(1000):
           y = np.log(1.0 + np.exp(x))
   print("original took %f s" %(time.time() - start))

   # Cython
   start = time.time()
   test1(x)
   print("Cython took %f s" %(time.time() - start))

実行結果

original took 30.662536 s
Cython took 30.764751 s

いくつか試した結果,Cythonでforループやnumpy arrayのインデキシングは非常に高速になっていました.
しかしながら,その高速化が誤差と思えるほどにnp.log, np.expが遅かったようです.

Numba

Numbaの場合.
高速化したい関数の上に@jitとデコレータを書き加えるだけで良さそうでした.
https://myenigma.hatenablog.com/entry/2017/03/02/155433

looptest.py
import numpy as np
import time
from numba import jit

from sample import test1

@jit
def test2(x):
   for n in range(100):
      for i in range(100):
         y = np.log(1.0 + np.exp(x))

if __name__ == "__main__":

   x = np.ones((300, 300), dtype=np.float32)

   # original
   start = time.time()
   for n in range(1000):
       for i in range(1000):
           y = np.log(1.0 + np.exp(x))
   print("original took %f s" %(time.time() - start))

   # Cython
   start = time.time()
   test1(x)
   print("Cython took %f s" %(time.time() - start))

   # Numba
   start = time.time()
   test2(x)
   print("Numba took %f s" %(time.time() - start))

実行結果

original took 30.954107 s
Cython took 30.807106 s
Numba took 153.410239 s

よく分からないが劇遅になった.
何かコードの誤りがあるかもしれません...

numexpr

numexprの場合.
不思議な書き方だが,以下のように書くみたい.
https://github.com/pydata/numexpr
http://memoryfolder.hatenablog.com/entry/2016/11/23/181250

looptest.py
import numpy as np
import time
from numba import jit
import numexpr as ne

from sample import test1

@jit
def test2(x):
   for n in range(100):
      for i in range(100):
         y = np.log(1.0 + np.exp(x))

if __name__ == "__main__":

   x = np.ones((300, 300), dtype=np.float32)

   # original
   start = time.time()
   for n in range(1000):
       for i in range(1000):
           y = np.log(1.0 + np.exp(x))
   print("original took %f s" %(time.time() - start))

   # Cython
   start = time.time()
   test1(x)
   print("Cython took %f s" %(time.time() - start))

   # Numba
   start = time.time()
   test2(x)
   print("Numba took %f s" %(time.time() - start))

   # numexpr
   start = time.time()
   for n in range(1000):
       for i in range(1000):
           y = ne.evaluate("log(1.0 + exp(x))").astype("float32")
   print("numexpr took %f s" %(time.time() - start))

実行結果

original took 30.707380 s
Cython took 30.750478 s
Numba took 153.180645 s
numexpr took 3.371960 s

よく分からないが非常に速くなった.
部分的な高速化であればnumexprで十分そう.

12
14
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
12
14

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?