LoginSignup
12
15

More than 5 years have passed since last update.

Numba使用を前提とした単純移動平均のPythonコードについて

Posted at

テクニカル指標で最も基本的な移動平均、そのなかでも単純移動平均(SMA)はただ平均を取るだけ、といってもSMA以外の多くのテクニカル指標の計算に使われています。実際、GitHubに掲載している30個あまりのテクニカル指標のうち、SMAを使っているものは4割にのぼります。

今回は、そのSMAに特化して、いくつかのPythonコードを比較してみたいと思います。

準備

Pythonのパッケージ一式をアップデートしたので、Pythonおよび、使用したパッケージのバージョンは以下の通りです。

  • Python 3.5.2
  • numpy 1.11.1
  • pandas 0.18.1
  • numba 0.26.0
  • scipy 0.17.1

まずは、
Pythonでランダムウォーク
を参考に100,000サンプルのランダムウォークを作っておきます。これをSMAの入力データとします。

import numpy as np
import pandas as pd
from numba import jit

dn = np.random.randint(2, size=100000)*2-1
gwalk = np.cumprod(np.exp(dn*0.01))*100

pandasのrolling,meanによる実装

SMAの一番簡単な実装は、pandasを使ったものです。Seriesのメソッドrolling,meanを使って簡単に書けます。

def SMA1(x, period):
    return pd.Series(x).rolling(period).mean()

共通の仕様として、引数に入力時系列とSMAの期間を入れます。期間の違いによる比較も行いため、period=20,200で計測しておきます。

%timeit y1_20 = SMA1(gwalk, 20)
%timeit y1_200 = SMA1(gwalk, 200)
100 loops, best of 3: 6.02 ms per loop
100 loops, best of 3: 6.01 ms per loop

pandasの場合、期間による実行速度の違いはないようです。

scipyのlfilterによる実装

Pythonで書いた移動平均の計算時間を比較してみた
を参考に、scipyのフィルタ関数lfilterを使って実装してみます。

from scipy.signal import lfilter
def SMA2(x, period):
    return lfilter(np.ones(period), 1, x)/period

同じように実行時間を計測してみます。

%timeit y2_20 = SMA2(gwalk, 20)
%timeit y2_200 = SMA2(gwalk, 200)
100 loops, best of 3: 5.53 ms per loop
100 loops, best of 3: 10.4 ms per loop

lfilterはSMA専用でなく汎用的なフィルタ関数なので、期間によって実行時間が変わるようです。期間が短ければpandasより速いですが、期間が長くなると遅くなります。

for文による実装(1)

SMAの計算式をfor文を使って直接書いてみます。当然、そのままだと遅くなることは目に見えているので、タイトルにもあるようにnumbaを使って高速化します。

@jit
def SMA3(x, period):
    y = np.zeros(len(x))
    for i in range(len(y)):
        for j in range(period):
            y[i] += x[i-j]
    return y/period
%timeit y3_20 = SMA3(gwalk, 20)
%timeit y3_200 = SMA3(gwalk, 200)
100 loops, best of 3: 3.07 ms per loop
10 loops, best of 3: 32.3 ms per loop

for文を使っていますが、numbaの高速化の効果があって期間が20の場合、これまでで最速です。ただ、期間に比例するので、200だと10倍遅くなってしまい、最遅となってしまいます。

for文による実装(2)

最後の実装は、SMAの特徴を利用した方法です。SMAはサンプルを単純に加算するだけなので、1サンプル前の計算結果を使って、古いサンプル値を引いて、新しいサンプル値を足す計算だけを行います。

@jit
def SMA4(x, period):
    y = np.empty(len(x))
    y[:period-1] = np.nan
    y[period-1] = np.sum(x[:period])
    for i in range(period, len(x)):
        y[i] = y[i-1]+x[i]-x[i-period]
    return y/period

サンプルが期間分揃うまで足していきますが、その後は3つのデータを加算するだけです。実行速度は以下のようになりました。

%timeit y4_20 = SMA4(gwalk, 20)
%timeit y4_200 = SMA4(gwalk, 200)
1 loop, best of 3: 727 µs per loop
1000 loops, best of 3: 780 µs per loop

これまでの実装のなかで最速の結果が出ました。さらに期間が長くなってもほとんど変わらない結果となりました。

以上のように、numbaによる高速化を前提とすれば、SMAに関しては、for文を使っても結構高速になることがわかりました。

12
15
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
12
15