LoginSignup
5
8

More than 5 years have passed since last update.

KerasのRNNを触ってみる②

Last updated at Posted at 2017-10-18

はじめに

この記事は、前回の続編です。
導入部分は前回を参照いただければと思います。

前回はSimpleRNNを
今回は、LSTMを使ってモデルを作ります。

1)モデル構築

学習データの作成部分は前回と同じなので割愛します。
今回は、LSTMクラスを使うので、モデル構築部分は下記のようになります。
(SimpleRNNがLSTMになっただけ)

model = Sequential()
model.add(InputLayer(batch_input_shape=(None, maxlen, in_out_dims)))
model.add(
    LSTM(units=hidden_dims, return_sequences=False))
model.add(Dense(in_out_dims))
model.add(Activation("linear"))
model.compile(loss="mean_squared_error", optimizer="rmsprop")

2)学習

ここも前回と同じです。
LSTMはSimpleRNNよりパラメータが多いので、時間がかかります。。

callbacks = []
# Early-stopping
callbacks.append(EarlyStopping(patience=0, verbose=1))
# CSVLogger
callbacks.append(CSVLogger("LSTM_history.csv"))
# fitting
history = model.fit(X_train, y_train, batch_size=100, epochs=100, validation_split=0.1, callbacks=callbacks)

3)結果の確認

誤差がSimpleRNNの時よりも減っています。(3.29 → 0.87)
image.png

*61回で打ち切り

4)教師データとの比較

前回と同様に訓練データの値と比較してみます。
(SimpleRNNも再掲します。)

precdict(SimpleRNN) precdict(LSTM) true
[17.03078651] [ 20.04093552] [ 22]
[14.13138199] [14.27271557] [14]
[ 11.11346722] [11.31826878] [11]
[7.76618958] [7.15918827] [7]
[6.90901232] [5.99405384] [6]
[ -1.77918458] [0.28667003] [0]

一行目の値が22にグッと近づきました。(まだ2離れてますが。。)
その他もちょっと誤差が小さくなっています。

とりあえず、これでLSTMの基本の確認はこれで良しとします。

5
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
8