TensorflowでRNNを実装する時に、tf.nn.static_rnnとtf.nn.dynamic_rnnのどっちを使えば良いか、その違いは何かで悩んだことがあります。

結論から言うと、Tensorflow 1.7では tf.nn.static_rnn も tf.nn.dynamic_rnn もコアの処理はほぼ同じなので、便利な dynamic_rnn を使った方が良いってことです。
下記からはその詳細になります。

TensorflowでRNNを実装する時、下記のような実装になります。

# tensorflow/python/ops/rnn.py の def dynamic_rnnのコメントから引用
# Tensorflow version 1.7
# BasicRNNCellの作成
rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)

# initial state の作成
initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32)

# ここが問題
outputs, state = tf.nn.dynamic_rnn(rnn_cell, input_data,
                                   initial_state=initial_state,
                                   dtype=tf.float32)

ここで悩むのは最後のtf.nn.dynamic_rnnのところで、同じ機能を持つ tf.nn.static_rnn の存在です。両方とも同じ機能を持ちますが、inputの形が異なるので何を最終的に使えば良いか悩んでしまいます。

でも、中身を見ると、両方ともコアのところはあまり差がないです。
すごくざっくり説明をすると、static_rnnはinputの前処理をユーザーに任して決められたグラフを書く反面、 dynamic_rnn は前処理などを全部内部で処理して、それを元にグラフを書く違いがあります。つまり、static_rnnはユーザーが最初にグラフの形に合わせてinputを作成する必要がありますが、dynamic_rnnは内部で inputデータ を分析して処理してくれるので、ユーザーがinputをグラフの形に合わせなくても良いことですね。
実際のコードを全部見せるには長すぎるので、そのコアの処理部分だけ紹介します。

まずは static_rnn

@tf_export("nn.static_rnn")
def static_rnn(cell,
               inputs,
               initial_state=None,
               dtype=None,
               sequence_length=None,
               scope=None):
    ...
    # いろんな前処理
    ...

    # ここがコアの処理部分
    for time, input_ in enumerate(inputs):
      if time > 0:
        varscope.reuse_variables()
      # pylint: disable=cell-var-from-loop
      call_cell = lambda: cell(input_, state)
      # pylint: enable=cell-var-from-loop
      if sequence_length is not None:
        (output, state) = _rnn_step(
            time=time,
            sequence_length=sequence_length,
            min_sequence_length=min_sequence_length,
            max_sequence_length=max_sequence_length,
            zero_output=zero_output,
            state=state,
            call_cell=call_cell,
            state_size=cell.state_size)
      else:
        (output, state) = call_cell()

    ....

ご覧の通り、inputsを_rnn_stepかcall_cellにループさせてますね。

次は dynamic_rnn です。
dynamic_rnnの場合、コア処理は _dynamic_rnn_loop で行われてます。

@tf_export("nn.dynamic_rnn")
def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
                dtype=None, parallel_iterations=None, swap_memory=False,
                time_major=False, scope=None):
  ... 
  # いろんな前処理
  ...
  # 前処理した後、_dynamic_rnn_loopを呼ぶ
  (outputs, final_state) = _dynamic_rnn_loop(
        cell,
        inputs,
        state,
        parallel_iterations=parallel_iterations,
        swap_memory=swap_memory,
        sequence_length=sequence_length,
        dtype=dtype)
  ...

# コア処理はここで担当
def _dynamic_rnn_loop(cell,
                      inputs,
                      initial_state,
                      parallel_iterations,
                      swap_memory,
                      sequence_length=None,
                      dtype=None):
  ...
  # いろんな前処理
  ...
  # _rnn_stepをラッピングするinner function
  def _time_step(time, output_ta_t, state):
    ... 
    # いろんな処理
    ...

    call_cell = lambda: cell(input_t, state)

    if sequence_length is not None:
      (output, new_state) = _rnn_step(
          time=time,
          sequence_length=sequence_length,
          min_sequence_length=min_sequence_length,
          max_sequence_length=max_sequence_length,
          zero_output=zero_output,
          state=state,
          call_cell=call_cell,
          state_size=state_size,
          skip_conditionals=True)
    else:
      (output, new_state) = call_cell()
  ...

  # _time_step(_rnn_step)をループさせる。
  _, output_final_ta, final_state = control_flow_ops.while_loop(
      cond=lambda time, *_: time < loop_bound,
      body=_time_step,
      loop_vars=(time, output_ta, state),
      parallel_iterations=parallel_iterations,
      maximum_iterations=time_steps,
      swap_memory=swap_memory)
  ...

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/rnn.py
ご覧の通り、static_rnnと違って処理が複雑になっていますが、コアのところは両方とも _rnn_stepを使って、似たような処理を行っています。なので、staticを使おうが、dynamicを使おうがそこまで違いはないと思いました。

でもやっぱり input の処理は dynamic_rnnの方が便利なので私は dynamic_rnnを使います。

Sign up for free and join this conversation.
Sign Up
If you already have a Qiita account log in.