LoginSignup
16
16

More than 3 years have passed since last update.

Cythonで高速化するにはどう書けばいいか

Posted at

はじめに:今回のお題

以下のような二次元配列を考えます。画像のピクセル情報と考えてください。1

In [3]: a
Out[3]:
array([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
       [ 8.,  9., 10., 11., 12., 13., 14., 15.],
       [16., 17., 18., 19., 20., 21., 22., 23.],
       [24., 25., 26., 27., 28., 29., 30., 31.],
       [32., 33., 34., 35., 36., 37., 38., 39.],
       [40., 41., 42., 43., 44., 45., 46., 47.]])

この配列から、ある座標(x, y)のn近傍(自分含む)を取り出して一次元に並べることを考えます(後続の処理で使いやすくするため)
例えば、(1, 1)の1近傍は以下のようになります2。なおNumPy配列的にはa[y, x]な点に注意。

In [5]: a[0:3, 0:3].reshape(-1)
Out[5]: array([ 0.,  1.,  2.,  8.,  9., 10., 16., 17., 18.])

この処理を画像サイズ分回すとくっそ遅かったので速くするにはどうすればいいかを考えたのが今回のお話となります。

NumPyでの高速化

一点だけでなく、縦横分回すと以下のようになります(pickup_loop1)。伏線込めてスライスを使わない超ベタループ(pickup_loop0)も掲載。

pickup.py
def pickup_loop1(a, n=0):
  height = a.shape[0]
  width  = a.shape[1]
  nheight = nwidth = 1 + n * 2
  b = np.empty((height - n * 2, width - n * 2, nheight * nwidth), np.double)
  for y in range(height - n * 2):
    for x in range(width - n * 2):
      b[y, x] = a[y : y + nheight, x : x + nwidth].reshape(-1)
  return b

def pickup_loop0(a, n=0):
  height = a.shape[0]
  width  = a.shape[1]
  nheight = nwidth = 1 + n * 2
  b = np.empty((height - n * 2, width - n * 2, nheight * nwidth), np.double)
  for y in range(height - n * 2):
    for x in range(width - n * 2):
      for j in range(nheight):
        for i in range(nwidth):
          b[y, x, j * nwidth + i] = a[y + j, x + i]
  return b

800x600の画像を想定したデータでn=1として処理時間を計測してみます。

In [6]: import pickup

In [7]: a = np.arange(600 * 800, dtype=np.double).reshape(600, 800)

In [8]: %timeit pickup.pickup_loop0(a, 1)
1.76 s ± 7.62 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [9]: %timeit pickup.pickup_loop1(a, 1)
760 ms ± 2.44 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

遅すぎて話になりません。これは本処理の前の初期化処理なのです。

Cythonの話の前に、「毎回reshapeしてるのが悪いんだよな」と思ったのでreshapeを最後にまとめてやるようにしてみました。

pickup.py
def pickup_loop2(a, n=0):
  height = a.shape[0]
  width  = a.shape[1]
  nheight = nwidth = 1 + n * 2
  b = np.empty((height - n * 2, width - n * 2, nheight, nwidth), np.double)
  for y in range(height - n * 2):
    for x in range(width - n * 2):
      b[y, x] = a[y : y + nheight, x : x + nwidth]
  return b.reshape(height - 2, width - 2, -1)

時間計測。速くはなったけどまだまだです。

In [10]: %timeit pickup.pickup_loop2(a, 1)
351 ms ± 1.27 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

「ださーい、NumPy使ってるのにfor文書いていいのは小学生までだよねー」という方については、Cython化の項を見る前にNG集をご覧ください。

Cython化

駄目なCythonの書き方

チュートリアルを参考にpickup_loop2をCython化してみます。紛らわしいですがpickup_loop_cy1という名前にします。変数への型付け、特にNumPy配列への型付けが重要と。3

pickup_cy.pyx
import numpy as np
cimport numpy as np

DTYPE = np.double
ctypedef np.double_t DTYPE_t

def pickup_loop_cy1(np.ndarray[DTYPE_t, ndim=2] a, int n=0):
  cdef int height, width
  cdef int nheight, nwidth
  cdef int y, x

  height = a.shape[0]
  width  = a.shape[1]
  nheight = nwidth = 1 + n * 2
  cdef np.ndarray[DTYPE_t, ndim=4] b = np.empty((height - n * 2, width - n * 2, nheight, nwidth), dtype=DTYPE)

  for y in range(height - n * 2):
    for x in range(width - n * 2):
      b[y, x] = a[y : y + nheight, x : x + nwidth]
  cdef np.ndarray[DTYPE_t, ndim=3] b2 = b.reshape(height - n * 2, width - n * 2, -1)
  return b2

時間を計ってみる。速くなってないやん。

In [11]: import pickup_cy

In [12]: %timeit pickup_cy.pickup_loop_cy1(a, 1)
332 ms ± 7.75 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

この理由は生成されたpickup_cy.cpickup_loop_cy1部分を確認するとわかります。

pickup_cy.c抜粋
      /* "pickup_cy.pyx":21
 *   for y in range(height - n * 2):
 *     for x in range(width - n * 2):
 *       b[y, x] = a[y : y + nheight, x : x + nwidth]             # <<<<<<<<<<<<<<
 *   cdef np.ndarray[DTYPE_t, ndim=3] b2 = b.reshape(height - n * 2, width - n * 2, -1)
 *   return b2
 */
      // 「y : y + nheight」に対応するコード
      __pyx_t_5 = __Pyx_PyInt_From_int(__pyx_v_y); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 21, __pyx_L1_error)
      __Pyx_GOTREF(__pyx_t_5);
      __pyx_t_7 = __Pyx_PyInt_From_int((__pyx_v_y + __pyx_v_nheight)); if (unlikely(!__pyx_t_7)) __PYX_ERR(0, 21, __pyx_L1_error)
      __Pyx_GOTREF(__pyx_t_7);
      __pyx_t_6 = PySlice_New(__pyx_t_5, __pyx_t_7, Py_None); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 21, __pyx_L1_error)
      __Pyx_GOTREF(__pyx_t_6);
      __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0;
      __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0;

      // 中略

      // 「b[y, x] =」に対応するコード
      __pyx_t_5 = __Pyx_PyInt_From_int(__pyx_v_y); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 21, __pyx_L1_error)
      __Pyx_GOTREF(__pyx_t_5);
      __pyx_t_6 = __Pyx_PyInt_From_int(__pyx_v_x); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 21, __pyx_L1_error)
      __Pyx_GOTREF(__pyx_t_6);
      __pyx_t_7 = PyTuple_New(2); if (unlikely(!__pyx_t_7)) __PYX_ERR(0, 21, __pyx_L1_error)
      __Pyx_GOTREF(__pyx_t_7);
      __Pyx_GIVEREF(__pyx_t_5);
      PyTuple_SET_ITEM(__pyx_t_7, 0, __pyx_t_5);
      __Pyx_GIVEREF(__pyx_t_6);
      PyTuple_SET_ITEM(__pyx_t_7, 1, __pyx_t_6);
      __pyx_t_5 = 0;
      __pyx_t_6 = 0;
      if (unlikely(PyObject_SetItem(((PyObject *)__pyx_v_b), __pyx_t_7, __pyx_t_3) < 0)) __PYX_ERR(0, 21, __pyx_L1_error)
      __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0;
      __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0;

Python APIを呼び出しています。これでは意味がない。

正しいCythonの書き方

SLICというアルゴリズムがあります。というか今回のお題の元ネタはSLICを実装してくれとの依頼でscikit-imageに用意されているslic関数が使えないか調査した際にコアの計算はCythonで書かれているということがわかったのですが、forが何重にもループしています。が、この関数今どきのマシンなら800x600画像を1秒ぐらいで処理できます。
NumPy脳なので「いやfor文自分で書くとかm9(^Д^)プギャー」と思ってたのですがどうやら逆にCythonではfor文をベタに回す(NumPyのスライス機能とか一切使わない)ようにすべきなのではないかということに気づき実装してみました。伏線しておいたpickup_loop0に対応するpickup_loop_cy2です。

pickup_cy.pyx
def pickup_loop_cy2(np.ndarray[DTYPE_t, ndim=2] a, int n=0):
  cdef int height, width
  cdef int nheight, nwidth
  cdef int y, x, j, i

  height = a.shape[0]
  width  = a.shape[1]
  nheight = nwidth = 1 + n * 2
  cdef np.ndarray[DTYPE_t, ndim=3] b = np.empty((height - n * 2, width - n * 2, nheight * nwidth), dtype=DTYPE)

  for y in range(height - n * 2):
    for x in range(width - n * 2):
      for j in range(nheight):
        for i in range(nwidth):
          b[y, x, j * nwidth + i] = a[y + j, x + i]
  return b

計測してみる。10倍以上速くなりました。

In [13]: %timeit pickup_cy.pickup_loop_cy2(a, 1)
24.4 ms ± 178 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

チュートリアルにあるように境界チェックとかを外してみます。

pickup_cy.pyx
@cython.boundscheck(False)
@cython.wraparound(False)
def pickup_loop_cy2a(np.ndarray[DTYPE_t, ndim=2] a, int n=0):
  # pickup_loop_cy2と同じコード

2倍、とは言えないけど速くなりました。

n [14]: %timeit pickup_cy.pickup_loop_cy2a(a, 1)
14.6 ms ± 63 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

別表記での検討

Cythonドキュメントのユーザーガイドの方を見るとCレベル配列アクセスの別の書き方がされています。データが連続する次元を指定すると速くなるとのことなのでやってみました。

pickup_cy.pyx
@cython.boundscheck(False)
@cython.wraparound(False)
def pickup_loop_cy3a(double[:, ::1] a, int n=0):
  cdef int height, width
  cdef int nheight, nwidth
  cdef int y, x, j, i

  height = a.shape[0]
  width  = a.shape[1]
  nheight = nwidth = 1 + n * 2
  cdef np.ndarray[DTYPE_t, ndim=3] b = np.empty((height - n * 2, width - n * 2, nheight * nwidth), dtype=DTYPE)
  cdef double[:, :, ::1] bv = b

  for y in range(height - n * 2):
    for x in range(width - n * 2):
      for j in range(nheight):
        for i in range(nwidth):
          bv[y, x, j * nwidth + i] = a[y + j, x + i]
  return b

計測。あまり変わりません。

In [15]: %timeit pickup_cy.pickup_loop_cy3a(a, 1)
13.8 ms ± 91 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

なお生成されるCコードの違いは以下のようになります。pickup_loop_cy3aの方が速そうな気もしますが最適化されるとほとんど変わりはない速度になるようです。
__Pyx_BufPtrStrided3d等は第1引数を見るとわかりますが関数呼び出しではなくマクロです。

pickup_cy.c抜粋
// pickup_loop_cy2a
          /* "pickup_cy.pyx":58
 *       for j in range(nheight):
 *         for i in range(nwidth):
 *           b[y, x, j * nwidth + i] = a[y + j, x + i]             # <<<<<<<<<<<<<<
 *   return b
 * 
 */
          __pyx_t_19 = (__pyx_v_y + __pyx_v_j);
          __pyx_t_20 = (__pyx_v_x + __pyx_v_i);
          __pyx_t_21 = __pyx_v_y;
          __pyx_t_22 = __pyx_v_x;
          __pyx_t_23 = ((__pyx_v_j * __pyx_v_nwidth) + __pyx_v_i);
          *__Pyx_BufPtrStrided3d(__pyx_t_9pickup_cy_DTYPE_t *, __pyx_pybuffernd_b.rcbuffer->pybuffer.buf, __pyx_t_21, __pyx_pybuffernd_b.diminfo[0].strides, __pyx_t_22, __pyx_pybuffernd_b.diminfo[1].strides, __pyx_t_23, __pyx_pybuffernd_b.diminfo[2].strides) = (*__Pyx_BufPtrStrided2d(__pyx_t_9pickup_cy_DTYPE_t *, __pyx_pybuffernd_a.rcbuffer->pybuffer.buf, __pyx_t_19, __pyx_pybuffernd_a.diminfo[0].strides, __pyx_t_20, __pyx_pybuffernd_a.diminfo[1].strides));

// pickup_loop_cy3a
          /* "pickup_cy.pyx":96
 *       for j in range(nheight):
 *         for i in range(nwidth):
 *           bv[y, x, j * nwidth + i] = a[y + j, x + i]             # <<<<<<<<<<<<<<
 *   return b
 * 
 */
          __pyx_t_20 = (__pyx_v_y + __pyx_v_j);
          __pyx_t_21 = (__pyx_v_x + __pyx_v_i);
          __pyx_t_22 = __pyx_v_y;
          __pyx_t_23 = __pyx_v_x;
          __pyx_t_24 = ((__pyx_v_j * __pyx_v_nwidth) + __pyx_v_i);
          *((double *) ( /* dim=2 */ ((char *) (((double *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_bv.data + __pyx_t_22 * __pyx_v_bv.strides[0]) ) + __pyx_t_23 * __pyx_v_bv.strides[1]) )) + __pyx_t_24)) )) = (*((double *) ( /* dim=1 */ ((char *) (((double *) ( /* dim=0 */ (__pyx_v_a.data + __pyx_t_20 * __pyx_v_a.strides[0]) )) + __pyx_t_21)) )));

並列化の検討

ユーザーガイドではさらにOpenMPを使って速くする方法が書かれていたので試してみました。

pickup_cy.pyx
@cython.boundscheck(False)
@cython.wraparound(False)
def pickup_loop_cy3b(double[:, ::1] a, int n=0):
  cdef int height, width
  cdef int nheight, nwidth
  cdef int y, x, j, i

  height = a.shape[0]
  width  = a.shape[1]
  nheight = nwidth = 1 + n * 2
  cdef np.ndarray[DTYPE_t, ndim=3] b = np.empty((height - n * 2, width - n * 2, nheight * nwidth), dtype=DTYPE)
  cdef double[:, :, ::1] bv = b

  for y in prange(height - n * 2, nogil=True):
    for x in range(width - n * 2):
      for j in range(nheight):
        for i in range(nwidth):
          bv[y, x, j * nwidth + i] = a[y + j, x + i]
  return b

計測。あまり変わりません。
データサイズを8000x6000にすると1.4倍ぐらい速くなるのですが現実的じゃないし、速度差が出ない理由はやってることがメモリアクセスだけだから(メモリアクセス速度の限界に引っかかるから)と思われます。

In [16]: %timeit pickup_cy.pickup_loop_cy3b(a, 1)
13.2 ms ± 118 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

まとめ

結果をまとめます。120倍(NumPyだけで努力したのに比べても24倍)速くなったのでスピード狂の私としても満足です。

関数名 説明 実行時間(ミリ秒) 実行時間比
pickup_loop0 NumPy forベタループ 1760.0 1.0
pickup_loop1 NumPy スライス使用 760.0 2.3
pickup_loop2 loop1のreshapeを最後にまとめる 351.0 5.0
pickup_loop_cy1 loop2を単純にCython化 332.0 5.3
pickup_loop_cy2 loop0のCython化 24.4 72.1
pickup_loop_cy2a 境界チェックなどをオフ 14.6 120.5

さて、記事中でも述べましたがCythonで高速化するためには、forループをベタに書くというのが今回得られた知見です。NumPy(自分でfor文書かない)で頑張ったけどまだ性能が出ないからCython使うという場合には気をつけるべきところでしょう。

NG集

NumPyのみでfor文なしを実現してみる

どうにかNumPyだけでできないかと考えて編み出したのが以下のコードになります。

pickup.py
def pickup_numpy(a, n=0):
  height = a.shape[0]
  width  = a.shape[1]
  nheight = nwidth = 1 + n * 2
  X, Y = np.meshgrid(np.arange(width - n * 2), np.arange(height - n * 2))
  I = np.stack((X.reshape(-1), Y.reshape(-1)), axis=-1)
  b = np.apply_along_axis(lambda i: a[i[1] : i[1] + nheight, i[0] : i[0] + nwidth], 1, I)
  return b.reshape(height - n * 2, width - n * 2, -1)

しかしこれは速くありません。

In [17]: %timeit pickup.pickup_numpy(a, 1)
1.28 s ± 19.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

速くないのは結果の配列を作るのにPython関数呼び出しを行っているためと思われます。4

typo注意

pickup_loop_cy2ですが初めはn=1をハードコーディングしてて後からnheight, nwidthを使うように書き直しました。

pickup_cy.pyx
def pickup_loop_cy2(np.ndarray[DTYPE_t, ndim=2] a, int n=0):
  cdef int height, width
  cdef int nheight, cwidth
  cdef int y, x, j, i

  height = a.shape[0]
  width  = a.shape[1]
  nheight = nwidth = 1 + n * 2
  cdef np.ndarray[DTYPE_t, ndim=3] b = np.empty((height - n * 2, width - n * 2, nheight * nwidth), dtype=DTYPE)

  for y in range(height - n * 2):
    for x in range(width - n * 2):
      for j in range(nheight):
        for i in range(nwidth):
          b[y, x, j * nwidth + i] = a[y + j, x + i]
  return b

計測してみると性能が出ない。なんで?変数使うといけないの??といろいろ調べた結果、

pickup_cy.pyx
def pickup_loop_cy2(np.ndarray[DTYPE_t, ndim=2] a, int n=0):
  cdef int height, width
  cdef int nheight, cwidth #←おい

もうお分かりだろう。
誰もnwidthを宣言してないのである!

型付けされてない変数はC的に処理されるのではなくPython的に処理されます。
というわけで、配列アクセスしているところも以下のようにPython API呼び出しがされてました・・・

pickup_cy.c抜粋
          /* "pickup_cy.pyx":39
 *       for j in range(nheight):
 *         for i in range(nwidth):
 *           b[y, x, j * nwidth + i] = a[y + j, x + i]             # <<<<<<<<<<<<<<
 *   return b
 * 
 */
          // 右辺
          __pyx_t_19 = (__pyx_v_y + __pyx_v_j);
          __pyx_t_20 = (__pyx_v_x + __pyx_v_i);
          __pyx_t_21 = -1;
          if (__pyx_t_19 < 0) {
            __pyx_t_19 += __pyx_pybuffernd_a.diminfo[0].shape;
            if (unlikely(__pyx_t_19 < 0)) __pyx_t_21 = 0;
          } else if (unlikely(__pyx_t_19 >= __pyx_pybuffernd_a.diminfo[0].shape)) __pyx_t_21 = 0;
          if (__pyx_t_20 < 0) {
            __pyx_t_20 += __pyx_pybuffernd_a.diminfo[1].shape;
            if (unlikely(__pyx_t_20 < 0)) __pyx_t_21 = 1;
          } else if (unlikely(__pyx_t_20 >= __pyx_pybuffernd_a.diminfo[1].shape)) __pyx_t_21 = 1;
          if (unlikely(__pyx_t_21 != -1)) {
            __Pyx_RaiseBufferIndexError(__pyx_t_21);
            __PYX_ERR(0, 39, __pyx_L1_error)
          }
          __pyx_t_4 = PyFloat_FromDouble((*__Pyx_BufPtrStrided2d(__pyx_t_9pickup_cy_DTYPE_t *, __pyx_pybuffernd_a.rcbuffer->pybuffer.buf, __pyx_t_19, __pyx_pybuffernd_a.diminfo[0].strides, __pyx_t_20, __pyx_pybuffernd_a.diminfo[1].strides))); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 39, __pyx_L1_error)

          // 左辺
          __Pyx_GOTREF(__pyx_t_4);
          __pyx_t_5 = __Pyx_PyInt_From_int(__pyx_v_y); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 39, __pyx_L1_error)
          __Pyx_GOTREF(__pyx_t_5);
          __pyx_t_6 = __Pyx_PyInt_From_int(__pyx_v_x); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 39, __pyx_L1_error)
          __Pyx_GOTREF(__pyx_t_6);
          __pyx_t_3 = __Pyx_PyInt_From_int(__pyx_v_j); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 39, __pyx_L1_error)
          __Pyx_GOTREF(__pyx_t_3);
          __pyx_t_2 = PyNumber_Multiply(__pyx_t_3, __pyx_v_nwidth); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 39, __pyx_L1_error)
          __Pyx_GOTREF(__pyx_t_2);
          __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0;
          __pyx_t_3 = __Pyx_PyInt_From_int(__pyx_v_i); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 39, __pyx_L1_error)
          __Pyx_GOTREF(__pyx_t_3);
          __pyx_t_22 = PyNumber_Add(__pyx_t_2, __pyx_t_3); if (unlikely(!__pyx_t_22)) __PYX_ERR(0, 39, __pyx_L1_error)
          __Pyx_GOTREF(__pyx_t_22);
          __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;
          __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0;
          __pyx_t_3 = PyTuple_New(3); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 39, __pyx_L1_error)
          __Pyx_GOTREF(__pyx_t_3);
          __Pyx_GIVEREF(__pyx_t_5);
          PyTuple_SET_ITEM(__pyx_t_3, 0, __pyx_t_5);
          __Pyx_GIVEREF(__pyx_t_6);
          PyTuple_SET_ITEM(__pyx_t_3, 1, __pyx_t_6);
          __Pyx_GIVEREF(__pyx_t_22);
          PyTuple_SET_ITEM(__pyx_t_3, 2, __pyx_t_22);
          __pyx_t_5 = 0;
          __pyx_t_6 = 0;
          __pyx_t_22 = 0;
          if (unlikely(PyObject_SetItem(((PyObject *)__pyx_v_b), __pyx_t_3, __pyx_t_4) < 0)) __PYX_ERR(0, 39, __pyx_L1_error)
          __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0;
          __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0;

  1. お題の処理的には値はなんでもいいのでarangeとreshapeでテストデータを作っています。またdoubleなのは元ネタがRGBではなくLabで処理しているためです。 

  2. n近傍で取ると画像サイズが2n小さくなってしまうため実際にはnp.padで広げてから処理していますが本筋ではないので割愛。 

  3. importとcimportで同じ名前を付けているのが気持ち悪いですがチュートリアルでもこう書いてあるので踏襲します。 

  4. 返される型がわからないので事前に配列領域を確保できないことも遅い原因、と思いましたが入力配列の長さはわかるし出力配列の型は一つ目の呼び出し結果からわかるのでこれについては動的確保が何回もされることはないと思われます。 

16
16
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
16
16