78
73

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

PythonコードをNumbaで高速化したときのメモ

Posted at

#はじめに
Pythonで書いた移動平均の計算時間を比較してみた
で、for文を使った移動平均(LWMA)が遅くて使い物にならないことがわかったのですが、MetaTraderのテクニカル指標のなかにはfor文使わないと書けないものもあって、それであきらめるわけにはいかないので、高速化に挑戦してみました。

とりあえず高速化にはCythonがあることは知っていたのですが、コードを書き換えなくてはいけないということだったので、ほかを調べてみたところ、Numbaというのがありました。今回はNumbaを試したときのメモです。

#for文を使ったとても遅いコード

import numpy as np
import pandas as pd
dataM1 = pd.read_csv('DAT_ASCII_EURUSD_M1_2015.csv', sep=';',
                     names=('Time','Open','High','Low','Close', ''),
                     index_col='Time', parse_dates=True)

def LWMA(s, ma_period):
    y = pd.Series(0.0, index=s.index)
    for i in range(len(y)):
        if i<ma_period-1: y[i] = 'NaN'
        else:
            y[i] = 0
            for j in range(ma_period):
                y[i] += s[i-j]*(ma_period-j)
            y[i] /= ma_period*(ma_period+1)/2
    return y

%timeit MA = LWMA(dataM1['Close'], 10)
1 loop, best of 3: 3min 14s per loop

前回の記事と同じく3分を超えてしまいます。

#Numbaを使ってみる

NumbaはAnacondaに入っているようなので、単にimportして、@numba.jitを追加するだけです。

import numba
@numba.jit
def LWMA(s, ma_period):
    y = pd.Series(0.0, index=s.index)
    for i in range(len(y)):
        if i<ma_period-1: y[i] = 'NaN'
        else:
            y[i] = 0
            for j in range(ma_period):
                y[i] += s[i-j]*(ma_period-j)
            y[i] /= ma_period*(ma_period+1)/2
    return y

%timeit MA = LWMA(dataM1['Close'], 10)
1 loop, best of 3: 3min 14s per loop

おや、同じ結果。全然効果がありません。Numbaっていうのは名前からしてNumpy専用なのかな?

#pandasをnumpyに変えてみる
入力データがpandasのSeries型だったのをnumpyのarray型に変えてみました。

@numba.jit
def LWMA(s, ma_period):
    y = np.zeros(len(s))
    for i in range(len(y)):
        if i<ma_period-1: y[i] = 'NaN'
        else:
            y[i] = 0
            for j in range(ma_period):
                y[i] += s[i-j]*(ma_period-j)
            y[i] /= ma_period*(ma_period+1)/2
    return y

%timeit MA = LWMA(dataM1['Close'].values, 10)
1 loop, best of 3: 2.11 s per loop

こんどは速くなりました。90倍くらい。ただ、scipyの数ミリ秒に比べるとまだまだ遅い。

#if文を除いてみる
コンパイルするにしても、コードはシンプルな方がいいに決まっているので、if文を除いてみました。実はこのif文、あってもなくてもいいコードだったのです。

@numba.jit
def LWMA(s, ma_period):
    y = np.zeros(len(s))
    for i in range(len(y)):
        for j in range(ma_period):
            y[i] += s[i-j]*(ma_period-j)
        y[i] /= ma_period*(ma_period+1)/2
    return y

%timeit MA = LWMA(dataM1['Close'].values, 10)
100 loops, best of 3: 5.73 ms per loop

出ました!ミリ秒。for文があってもscipy並みに高速化することができました。やればできるじゃないか、Python。

#まとめ

for文を使って遅くなったコードもNumbaを使うことで高速化できました。ただし、効果があるのはnumpyに対してで、pandasに対しては全く効果はありませんでした。

#参考記事
NumPyでfor文を使うと遅いと思ったがそんな事はなかったぜ

78
73
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
78
73

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?