TensorflowでDynamic Recurrent Neural Network(入力の系列長が可変のRNN)を計算する方法としてaymericdamien/TensorFlow-Examplesで紹介されているtf.nn.rnn
を利用する方法がありますが、tf.nn.dynamic_rnn
を用いるとより簡単に実装できました。
(TensorFlow v0.10.0rc0で動作を確認しています)
ポイントは入力をゼロ埋めした上で、tf.nn.dynamic_rnn
の引数sequence_length
にゼロ埋めする前の系列の長さを渡すことです。
まず、placeholderとしてinput
の他に、各系列の長さを格納するsequence_length
を用意します。
input = tf.placeholder(tf.float32, [None, seq_max_len, nin])
sequence_length = tf.placeholder(tf.int32, [None])
それからいつものようにセルを準備して(このセルはGRUじゃなくても何でも良いです)
cell = tf.nn.rnn_cell.GRUCell(n_hidden)
initial_state = cell.zero_state(batch_size, dtype=tf.float32)
出力を得ます(sequence_lengthを引数に与える)。
output, _ = tf.nn.dynamic_rnn(
cell, input, initial_state=initial_state,
sequence_length=sequence_length)
最後に各系列の最後の出力を取り出す必要があります。tf.gather
は1-Dのインデックスしかサポートしていない(逆にtf.gather_nd
は勾配計算をサポートしていない… #206)ので、一度tf.reshape
を使ってoutput
をflattenしてからtf.gather
しています。
index = tf.range(0, batch_size) * seq_max_len + (sequence_length - 1)
output = tf.gather(tf.reshape(output, [-1, n_hidden]), index)
擬似的なデータを使った完全なプログラムはここに置いてあります。
…しかし早くtheanoみたいにnumpyライクなインデックシングしたいですねぇ。