LoginSignup
0
4

More than 5 years have passed since last update.

ステートフルRNN

Posted at

ステートフルRNNとは?

ステートレス(デフォルト)
・各バッチ後に状態をリセットする
ステートフル
・バッチ間で状態を維持する
・1バッチに対して計算された隠れ状態を次のバッチの初期隠れ状態として使う

コードの違い
①モデル
②学習

①モデル

モデル(ステートレス)

model = Sequential()
model.add(LSTM(HIDDEN_SIZE, input_shape=(NUM_TIMESTEPS, 1), return_sequences=False))
model.add(Dense(1))

モデル(ステートフル)

model = Sequential()
model.add(LSTM(HIDDEN_SIZE, stateful=True,
          #バッチサイズもあらかじめ指定しておく
          batch_input_shape=(BATCH_SIZE, NUM_TIMESTEPS, 1),
          return_sequences=False))
model.add(Dense(1))

②学習

学習(ステートレス)

model.fit(Xtrain, Ytrain, epochs=NUM_EPOCHS, batch_size=BATCH_SIZE,
          validation_data=(Xtest, Ytest),
          shuffle=False)

学習(ステートフル)

# 訓練データ・テストデータのサイズをバッチサイズの倍数にする
train_size = (Xtrain.shape[0] // BATCH_SIZE) * BATCH_SIZE
test_size = (Xtest.shape[0] // BATCH_SIZE) * BATCH_SIZE
Xtrain, Ytrain = Xtrain[0:train_size], Ytrain[0:train_size]
Xtest, Ytest = Xtest[0:test_size], Ytest[0:test_size]
print(Xtrain.shape, Xtest.shape, Ytrain.shape, Ytest.shape)
for i in range(NUM_EPOCHS):
    print("Epoch {:d}/{:d}".format(i+1, NUM_EPOCHS))
    model.fit(Xtrain, Ytrain, batch_size=BATCH_SIZE, epochs=1,
              validation_data=(Xtest, Ytest),
              # ステートフルで効果的に学習をするためにシャッフルをしない
              shuffle=False)
    # 1エポックごとに状態をリセット
    model.reset_states()
0
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
4