1
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Rのkerasで時系列予測(LSTM)

Last updated at Posted at 2021-02-16

#すること
R言語でkerasを使って時系列予測をしてみました。
株価推移をサンプルにした、LSTMによる時系列予測です。
前日までの終値をもとに、翌日の終値を予測してみます。

※Windowsのユーザ名が半角でないと、途中でうまくインストールできないので注意
#データの前処理
##株価取得

library(BatchGetSymbols)
tickers=c('%5EIBEX')
firstdate=Sys.Date()-360*15
lastdate=Sys.Date()

stts=BatchGetSymbols(tickers=tickers, first.date=firstdate, last.date=lastdate,
                     cache.folder=file.path(tempdir(),'BGS_Cache'))

#終値と出来高のデータフレームに書き換え(indexにはとりあえず日付をセット)
stts=data.frame(index=stts$df.tickers$ref.date,price=stts$df.tickers$price.close,vol=stts$df.tickers$volume)
#N/Aを除去
stts=stts[complete.cases(stts),]
#現在からさかのぼって3050件取得
stts=stts[-seq(nrow(stts)-3050),]
#index連番に置き換え
stts$index=seq(nrow(stts))

#縦軸終値で散布図を作成(出来高は色)
library(plotly)
plot_ly(stts,x=~index,y=~price,type="scatter",mode="markers",color=~vol)

image.png

##標準化
基本的にニューラルネットワークに投入するデータは標準化が必須です。

#各値について平均との差を求めて標準偏差で割ります
mprice=mean(stts$price)
sdprice=sd(stts$price)
stts$price=(stts$price-mprice)/sdprice
#出来高も同様に
mvol=mean(stts$vol)
sdvol=sd(stts$vol)
stts$vol=(stts$vol-mvol)/sdvol

##LSTM用にデータを成型
今回は、1日目の終値(P1)から50日目の終値(P50)を使って51日目の終値(P51)を予測することを考えます。
image.png

1日ずつずらしたデータで訓練データをたくさん生成します
image.png
今回は2000個作りました。
これで、2000 x 50 x 1 と 2000 x 1 からなる訓練データが作成できました。

次に、予測に出来高の値(Vn)も使いたいので、終値同様にデータを作成し、先の終値のデータと結合します。
以下のような 2000 x 50 x 2 と 2000 x 1 からなる訓練データを作成します。
image.png

これをコードで書いてみましょう。

#予測に使うデータ数
window=50

data.x.price=NULL
data.x.vol=NULL
data.y=NULL
for(i in c(1:2000)){
  #1日ずつずらした50日分の行を足して 2000 x 50 のmatrixを作る
  data.x.price=rbind(data.x.price, stts$price[i:(window+i-1)])
  data.x.vol=rbind(data.x.vol,stts$vol[i:(window+i-1)])
  #プラス1日目の終値で 2000 x 1 のmatrixを作る
  data.y=rbind(data.y,stts$price[window+i])
}
#終値の 2000 x 50 のmatrixの横に、出来高の 2000 x 50 のmatrixをくっつけて 2000 x 100 のmatrixを作り
#それを 2000 x 50 x 2に成型する 
data.x=array(cbind(data.x.price,data.x.vol),dim=c(2000,window,2))

次に訓練用データ同様にテスト用データを作成します。

test.x.price=NULL
test.x.vol=NULL
test.y=NULL
for(i in c(2001:3000)){
  #1日ずつずらした50日分の行を足して 1000 x 50 のmatrixを作る
  test.x.price=rbind(test.x.price, stts$price[i:(window+i-1)])
  test.x.vol=rbind(test.x.vol,stts$vol[i:(window+i-1)])
  #プラス1日目の終値で 1000 x 1 のmatrixを作る
  test.y=rbind(test.y,stts$price[window+i])
}
#終値の 1000 x 50 のmatrixの横に、出来高の 1000 x 50 のmatrixをくっつけて 1000 x 100 のmatrixを作り
#それを 1000 x 50 x 2に成型する 
test.x=array(cbind(test.x.price,test.x.vol),dim=c(1000,window,2))

#LSTMモデルの作成
kerasの時系列モデルを読み込みます。R起動後の1回目はこれに案外時間がかかります。
初めて実行するときは、ここでMinicondaのインストールを要求されますので、促されるままにインストールしましょう。
※Windowsのユーザ名が半角でないと、ここでつまづきます。

library(keras)
model=keras_model_sequential()

次にLSTMモデルの細かい定義を設定しコンパイルします。

#損失関数を評価する訓練数の単位
#今回の場合2000/10=200回損失係数が評価される
batchsize=10

#LSTMモデルの設定
model %>%
   layer_lstm(units = 100, #ノードの数?
              input_shape=c(window, 2), #説明変数が 50 x 2(50日分の終値と出来高)ということ
              batch_size=batchsize,
              return_sequences=TRUE,
              stateful=TRUE) %>%
   layer_dropout(rate=0.5) %>%
   layer_lstm(units=50,
              return_sequences=FALSE,
              stateful=TRUE) %>%
   layer_dropout(rate=0.5) %>%
   layer_dense(units=1)

#モデルのコンパイル
model %>%
   compile(loss="mean_squared_error",optimizer=optimizer_adam(),metric="accuracy")

作成したモデルがちゃんとできているか確認します。

#作成したモデルの確認
model

#モデルのトレーニング
訓練データを使ってトレーニングします。
同じデータで何回もトレーニングさせることで精度が向上します。
同じデータで訓練させる回数をエポック数と言います。

model %>%
   fit(x=data.x,
       y=data.y,
       batch_size=batchsize,
       epochs = 15,
       vervose = 1, #1エポック毎に損失係数のグラフをプロットします
       shuffle=FALSE)

#予測
予測はpredict関数を使います

pred_out=model %>% predict(test.x, batch_size=batchsize)

プロットしてみましょう

plot_ly(stts, x=~index,y=~price,type="scatter",mode="markers",color=~vol) %>%
   add_trace(y=c(rep(NA,2050),pred_out[1:1000,]),x=stts$index,name="LSTM prediction", mode="lines")

IMG_20210217_193220.jpg

また、予測vs実績のグラフ(y=xになると予測精度が高い)も書いてみましょう。

plot(x=test.y,y=pred_out)

IMG_20210217_193016.jpg

###参考にしたサイト
R-BLOGGERS R news and tutorials contoributed by hundreds of R bloggers
※このサイトのデータの作り方が間違えていた(予測値もむちゃくちゃだった)ので、その修正版の記事として今回この記事を書きました。

1
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?