#ステートフルRNNとは?
ステートレス(デフォルト)
・各バッチ後に状態をリセットする
ステートフル
・バッチ間で状態を維持する
・1バッチに対して計算された隠れ状態を次のバッチの初期隠れ状態として使う
コードの違い
①モデル
②学習
#①モデル
##モデル(ステートレス)
model = Sequential()
model.add(LSTM(HIDDEN_SIZE, input_shape=(NUM_TIMESTEPS, 1), return_sequences=False))
model.add(Dense(1))
##モデル(ステートフル)
model = Sequential()
model.add(LSTM(HIDDEN_SIZE, stateful=True,
#バッチサイズもあらかじめ指定しておく
batch_input_shape=(BATCH_SIZE, NUM_TIMESTEPS, 1),
return_sequences=False))
model.add(Dense(1))
#②学習
##学習(ステートレス)
model.fit(Xtrain, Ytrain, epochs=NUM_EPOCHS, batch_size=BATCH_SIZE,
validation_data=(Xtest, Ytest),
shuffle=False)
##学習(ステートフル)
# 訓練データ・テストデータのサイズをバッチサイズの倍数にする
train_size = (Xtrain.shape[0] // BATCH_SIZE) * BATCH_SIZE
test_size = (Xtest.shape[0] // BATCH_SIZE) * BATCH_SIZE
Xtrain, Ytrain = Xtrain[0:train_size], Ytrain[0:train_size]
Xtest, Ytest = Xtest[0:test_size], Ytest[0:test_size]
print(Xtrain.shape, Xtest.shape, Ytrain.shape, Ytest.shape)
for i in range(NUM_EPOCHS):
print("Epoch {:d}/{:d}".format(i+1, NUM_EPOCHS))
model.fit(Xtrain, Ytrain, batch_size=BATCH_SIZE, epochs=1,
validation_data=(Xtest, Ytest),
# ステートフルで効果的に学習をするためにシャッフルをしない
shuffle=False)
# 1エポックごとに状態をリセット
model.reset_states()