10
10

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

tf.nn.dynamic_rnnでDynamic Recurrent Neural Network

Posted at

TensorflowでDynamic Recurrent Neural Network(入力の系列長が可変のRNN)を計算する方法としてaymericdamien/TensorFlow-Examplesで紹介されているtf.nn.rnnを利用する方法がありますが、tf.nn.dynamic_rnnを用いるとより簡単に実装できました。
(TensorFlow v0.10.0rc0で動作を確認しています)


ポイントは入力をゼロ埋めした上で、tf.nn.dynamic_rnnの引数sequence_lengthにゼロ埋めする前の系列の長さを渡すことです。

まず、placeholderとしてinputの他に、各系列の長さを格納するsequence_lengthを用意します。

input = tf.placeholder(tf.float32, [None, seq_max_len, nin])
sequence_length = tf.placeholder(tf.int32, [None])

それからいつものようにセルを準備して(このセルはGRUじゃなくても何でも良いです)

cell = tf.nn.rnn_cell.GRUCell(n_hidden)
initial_state = cell.zero_state(batch_size, dtype=tf.float32)

出力を得ます(sequence_lengthを引数に与える)。

output, _ = tf.nn.dynamic_rnn(
    cell, input, initial_state=initial_state,
    sequence_length=sequence_length)

最後に各系列の最後の出力を取り出す必要があります。tf.gatherは1-Dのインデックスしかサポートしていない(逆にtf.gather_ndは勾配計算をサポートしていない… #206)ので、一度tf.reshapeを使ってoutputをflattenしてからtf.gatherしています。

index = tf.range(0, batch_size) * seq_max_len + (sequence_length - 1)
output = tf.gather(tf.reshape(output, [-1, n_hidden]), index)

擬似的なデータを使った完全なプログラムはここに置いてあります。

…しかし早くtheanoみたいにnumpyライクなインデックシングしたいですねぇ。

10
10
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
10
10

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?