kerasのRNN系APIのGRUの引数return_state, return_sequencesについて
大雑把に書きました。
環境
python3
使用する擬似データ
B = 1 #バッチサイズ
T = 10 #時系列長
N = 1000 #特徴量
data = np.random.randn(B, T, N)
使用するRNN系インターフェース
tf.keras.layers.GRU
return_state=True, return_sequences=True
赤丸 : return_sequencesをTrue時
緑丸 : return_statesをTrue時
gru = tf.keras.layers.GRU(256, return_state=True, return_sequences=True)
B = 1
T = 10
N = 1000
data = np.random.randn(B, T, N)
outputs, states = gru(data)
print("赤丸:", outputs.shape)
print("緑丸:", states.shape)
赤丸: (1, 10, 256)
緑丸: (1, 256)
return_state=True, return_sequences=False
赤丸 : return_sequencesをFalse時
緑丸 : return_statesをTrue時
gru = tf.keras.layers.GRU(256, return_state=True, return_sequences=False)
B = 1
T = 10
N = 1000
data = np.random.randn(B, T, N)
outputs, states = gru(data)
print("赤丸:", outputs.shape)
print("緑丸:", states.shape)
赤丸: (1, 256)
緑丸: (1, 256)
return_state=False, return_sequences=True
gru = tf.keras.layers.GRU(256, return_state=False, return_sequences=True)
B = 1
T = 10
N = 1000
data = np.random.randn(B, T, N)
outputs = gru(data)
print("赤丸:", outputs.shape)
print("緑丸なし")
赤丸: (1, 10, 256)
緑丸なし
return_state=False, return_sequences=False
gru = tf.keras.layers.GRU(256, return_state=False, return_sequences=False)
B = 1
T = 10
N = 1000
data = np.random.randn(B, T, N)
outputs = gru(data)
print("赤丸:", outputs.shape)
print("緑丸なし")
赤丸: (1, 256)
緑丸なし