TensorflowでRNNを実装する時に、tf.nn.static_rnnとtf.nn.dynamic_rnnのどっちを使えば良いか、その違いは何かで悩んだことがあります。
結論から言うと、Tensorflow 1.7では tf.nn.static_rnn も tf.nn.dynamic_rnn もコアの処理はほぼ同じなので、便利な dynamic_rnn を使った方が良いってことです。
下記からはその詳細になります。
TensorflowでRNNを実装する時、下記のような実装になります。
# tensorflow/python/ops/rnn.py の def dynamic_rnnのコメントから引用
# Tensorflow version 1.7
# BasicRNNCellの作成
rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)
# initial state の作成
initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32)
# ここが問題
outputs, state = tf.nn.dynamic_rnn(rnn_cell, input_data,
initial_state=initial_state,
dtype=tf.float32)
ここで悩むのは最後のtf.nn.dynamic_rnnのところで、同じ機能を持つ tf.nn.static_rnn の存在です。両方とも同じ機能を持ちますが、inputの形が異なるので何を最終的に使えば良いか悩んでしまいます。
でも、中身を見ると、両方ともコアのところはあまり差がないです。
すごくざっくり説明をすると、static_rnnはinputの前処理をユーザーに任して決められたグラフを書く反面、 dynamic_rnn は前処理などを全部内部で処理して、それを元にグラフを書く違いがあります。つまり、static_rnnはユーザーが最初にグラフの形に合わせてinputを作成する必要がありますが、dynamic_rnnは内部で inputデータ を分析して処理してくれるので、ユーザーがinputをグラフの形に合わせなくても良いことですね。
実際のコードを全部見せるには長すぎるので、そのコアの処理部分だけ紹介します。
まずは static_rnn
@tf_export("nn.static_rnn")
def static_rnn(cell,
inputs,
initial_state=None,
dtype=None,
sequence_length=None,
scope=None):
...
# いろんな前処理
...
# ここがコアの処理部分
for time, input_ in enumerate(inputs):
if time > 0:
varscope.reuse_variables()
# pylint: disable=cell-var-from-loop
call_cell = lambda: cell(input_, state)
# pylint: enable=cell-var-from-loop
if sequence_length is not None:
(output, state) = _rnn_step(
time=time,
sequence_length=sequence_length,
min_sequence_length=min_sequence_length,
max_sequence_length=max_sequence_length,
zero_output=zero_output,
state=state,
call_cell=call_cell,
state_size=cell.state_size)
else:
(output, state) = call_cell()
....
ご覧の通り、inputsを_rnn_stepかcall_cellにループさせてますね。
次は dynamic_rnn です。
dynamic_rnnの場合、コア処理は _dynamic_rnn_loop で行われてます。
@tf_export("nn.dynamic_rnn")
def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
dtype=None, parallel_iterations=None, swap_memory=False,
time_major=False, scope=None):
...
# いろんな前処理
...
# 前処理した後、_dynamic_rnn_loopを呼ぶ
(outputs, final_state) = _dynamic_rnn_loop(
cell,
inputs,
state,
parallel_iterations=parallel_iterations,
swap_memory=swap_memory,
sequence_length=sequence_length,
dtype=dtype)
...
# コア処理はここで担当
def _dynamic_rnn_loop(cell,
inputs,
initial_state,
parallel_iterations,
swap_memory,
sequence_length=None,
dtype=None):
...
# いろんな前処理
...
# _rnn_stepをラッピングするinner function
def _time_step(time, output_ta_t, state):
...
# いろんな処理
...
call_cell = lambda: cell(input_t, state)
if sequence_length is not None:
(output, new_state) = _rnn_step(
time=time,
sequence_length=sequence_length,
min_sequence_length=min_sequence_length,
max_sequence_length=max_sequence_length,
zero_output=zero_output,
state=state,
call_cell=call_cell,
state_size=state_size,
skip_conditionals=True)
else:
(output, new_state) = call_cell()
...
# _time_step(_rnn_step)をループさせる。
_, output_final_ta, final_state = control_flow_ops.while_loop(
cond=lambda time, *_: time < loop_bound,
body=_time_step,
loop_vars=(time, output_ta, state),
parallel_iterations=parallel_iterations,
maximum_iterations=time_steps,
swap_memory=swap_memory)
...
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/rnn.py
ご覧の通り、static_rnnと違って処理が複雑になっていますが、コアのところは両方とも _rnn_stepを使って、似たような処理を行っています。なので、staticを使おうが、dynamicを使おうがそこまで違いはないと思いました。
でもやっぱり input の処理は dynamic_rnnの方が便利なので私は dynamic_rnnを使います。