Edited at

TensorFlow RNN関連のimportやRNN(LSTM)Cellでエラーが出た場合の対処(v 0.11r~)

More than 1 year has passed since last update.

以前TensorFlow RNNの初学用サンプルコードに書いた(TensorFlowのRNNを基本的なモデルで試す)のですが,TensorFlowのバージョンがあがり,動かなくなっている部分があるので簡単にまとめておきます.


環境

OSX

python 3.5

TensorFlow r0.11


import

RNN関連のinportを行う際,v.0.8までは

from tensorflow.models.rnn import rnn, rnn_cell

としていましたが,repackaging関連の問題で,

ImportError: This module is deprecated.  Use tf.nn.rnn_* instead.

というエラーが出てしまいます.

そのため,rnn, rnn_cellのimportの部分は消してしまい,rnnやrnn_cellを使用している部分を

# rnn_cell.BasicLSTMCell( ->

tf.nn.rnn_cell.BasicLSTMCell(...

# rnn( ->
tf.nn.rnn(...

と変更してあげてください.


BasicLSTMCell

上記エラーを修正しても


cell = tf.nn.rnn_cell.BasicLSTMCell(
num_of_hidden_nodes, forget_bias=forget_bias)
rnn_output, states_op = tf.nn.rnn(cell, in4, initial_state=istate_ph)

のような部分で,

TypeError: 'Tensor' object is not iterable.

というエラーが吐かれることがあります.

これは,BasicLSTMCellのデフォルト引数が変わったのが原因で,戻り値の形式としてtuple型を許すか,結合させるかを決めるstate_is_tupleがFalseからTrueに変わっているためです.

そのため,上記コードであれば


cell = tf.nn.rnn_cell.BasicLSTMCell(
num_of_hidden_nodes, forget_bias=forget_bias, state_is_tuple=False)
rnn_output, states_op = tf.nn.rnn(cell, in4, initial_state=istate_ph)

BasicLSTMCellの引数に,state_is_tuple=Falseを指定してあげればよいです.

ただし,これに関しては公式ドキュメントにて,そのうち廃止にするよ,と書いてあるのでゆくゆくはまた書き方を変える必要がありそうです.


参考