LoginSignup
2
3

More than 3 years have passed since last update.

kerasのRNN系APIの引数return_state, return_sequencesについて

Posted at

kerasのRNN系APIのGRUの引数return_state, return_sequencesについて

大雑把に書きました。

環境
python3

使用する擬似データ

B = 1  #バッチサイズ
T = 10 #時系列長
N = 1000 #特徴量
data = np.random.randn(B, T, N)

使用するRNN系インターフェース

tf.keras.layers.GRU

return_state=True, return_sequences=True

image.png

赤丸 : return_sequencesをTrue時
緑丸 : return_statesをTrue時

gru = tf.keras.layers.GRU(256, return_state=True, return_sequences=True)
B = 1
T = 10
N = 1000
data = np.random.randn(B, T, N)
outputs, states = gru(data)
print("赤丸:", outputs.shape)
print("緑丸:", states.shape)
赤丸: (1, 10, 256)
緑丸: (1, 256)

return_state=True, return_sequences=False

image.png

赤丸 : return_sequencesをFalse時
緑丸 : return_statesをTrue時

gru = tf.keras.layers.GRU(256, return_state=True, return_sequences=False)
B = 1
T = 10
N = 1000
data = np.random.randn(B, T, N)
outputs, states = gru(data)
print("赤丸:", outputs.shape)
print("緑丸:", states.shape)
赤丸: (1, 256)
緑丸: (1, 256)

return_state=False, return_sequences=True

image.png

gru = tf.keras.layers.GRU(256, return_state=False, return_sequences=True)
B = 1
T = 10
N = 1000
data = np.random.randn(B, T, N)
outputs = gru(data)
print("赤丸:", outputs.shape)
print("緑丸なし")
赤丸: (1, 10, 256)
緑丸なし

return_state=False, return_sequences=False

image.png

gru = tf.keras.layers.GRU(256, return_state=False, return_sequences=False)
B = 1
T = 10
N = 1000
data = np.random.randn(B, T, N)
outputs = gru(data)
print("赤丸:", outputs.shape)
print("緑丸なし")
赤丸: (1, 256)
緑丸なし
2
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
3