6
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Tensorflowのstatic_rnnとdynamic_rnnの解析

Last updated at Posted at 2018-03-31

TensorflowでRNNを実装する時に、tf.nn.static_rnnとtf.nn.dynamic_rnnのどっちを使えば良いか、その違いは何かで悩んだことがあります。

結論から言うと、Tensorflow 1.7では tf.nn.static_rnn も tf.nn.dynamic_rnn もコアの処理はほぼ同じなので、便利な dynamic_rnn を使った方が良いってことです。
下記からはその詳細になります。

TensorflowでRNNを実装する時、下記のような実装になります。

# tensorflow/python/ops/rnn.py の def dynamic_rnnのコメントから引用
# Tensorflow version 1.7
# BasicRNNCellの作成
rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)

# initial state の作成
initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32)

# ここが問題
outputs, state = tf.nn.dynamic_rnn(rnn_cell, input_data,
                                   initial_state=initial_state,
                                   dtype=tf.float32)

ここで悩むのは最後のtf.nn.dynamic_rnnのところで、同じ機能を持つ tf.nn.static_rnn の存在です。両方とも同じ機能を持ちますが、inputの形が異なるので何を最終的に使えば良いか悩んでしまいます。

でも、中身を見ると、両方ともコアのところはあまり差がないです。
すごくざっくり説明をすると、static_rnnはinputの前処理をユーザーに任して決められたグラフを書く反面、 dynamic_rnn は前処理などを全部内部で処理して、それを元にグラフを書く違いがあります。つまり、static_rnnはユーザーが最初にグラフの形に合わせてinputを作成する必要がありますが、dynamic_rnnは内部で inputデータ を分析して処理してくれるので、ユーザーがinputをグラフの形に合わせなくても良いことですね。
実際のコードを全部見せるには長すぎるので、そのコアの処理部分だけ紹介します。

まずは static_rnn

    
@tf_export("nn.static_rnn")
def static_rnn(cell,
               inputs,
               initial_state=None,
               dtype=None,
               sequence_length=None,
               scope=None):
    ...
    # いろんな前処理
    ...
    
    # ここがコアの処理部分
    for time, input_ in enumerate(inputs):
      if time > 0:
        varscope.reuse_variables()
      # pylint: disable=cell-var-from-loop
      call_cell = lambda: cell(input_, state)
      # pylint: enable=cell-var-from-loop
      if sequence_length is not None:
        (output, state) = _rnn_step(
            time=time,
            sequence_length=sequence_length,
            min_sequence_length=min_sequence_length,
            max_sequence_length=max_sequence_length,
            zero_output=zero_output,
            state=state,
            call_cell=call_cell,
            state_size=cell.state_size)
      else:
        (output, state) = call_cell()
    
    ....

ご覧の通り、inputsを_rnn_stepかcall_cellにループさせてますね。

次は dynamic_rnn です。
dynamic_rnnの場合、コア処理は _dynamic_rnn_loop で行われてます。


@tf_export("nn.dynamic_rnn")
def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
                dtype=None, parallel_iterations=None, swap_memory=False,
                time_major=False, scope=None):
  ... 
  # いろんな前処理
  ...
  # 前処理した後、_dynamic_rnn_loopを呼ぶ
  (outputs, final_state) = _dynamic_rnn_loop(
        cell,
        inputs,
        state,
        parallel_iterations=parallel_iterations,
        swap_memory=swap_memory,
        sequence_length=sequence_length,
        dtype=dtype)
  ...

# コア処理はここで担当
def _dynamic_rnn_loop(cell,
                      inputs,
                      initial_state,
                      parallel_iterations,
                      swap_memory,
                      sequence_length=None,
                      dtype=None):
  ...
  # いろんな前処理
  ...
  # _rnn_stepをラッピングするinner function
  def _time_step(time, output_ta_t, state):
    ... 
    # いろんな処理
    ...

    call_cell = lambda: cell(input_t, state)

    if sequence_length is not None:
      (output, new_state) = _rnn_step(
          time=time,
          sequence_length=sequence_length,
          min_sequence_length=min_sequence_length,
          max_sequence_length=max_sequence_length,
          zero_output=zero_output,
          state=state,
          call_cell=call_cell,
          state_size=state_size,
          skip_conditionals=True)
    else:
      (output, new_state) = call_cell()
  ...

  # _time_step(_rnn_step)をループさせる。
  _, output_final_ta, final_state = control_flow_ops.while_loop(
      cond=lambda time, *_: time < loop_bound,
      body=_time_step,
      loop_vars=(time, output_ta, state),
      parallel_iterations=parallel_iterations,
      maximum_iterations=time_steps,
      swap_memory=swap_memory)
  ...

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/rnn.py
ご覧の通り、static_rnnと違って処理が複雑になっていますが、コアのところは両方とも _rnn_stepを使って、似たような処理を行っています。なので、staticを使おうが、dynamicを使おうがそこまで違いはないと思いました。

でもやっぱり input の処理は dynamic_rnnの方が便利なので私は dynamic_rnnを使います。

6
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?