3
5

More than 5 years have passed since last update.

TensorFlow RNN関連のimportやRNN(LSTM)Cellでエラーが出た場合の対処(v 0.11r~)

Last updated at Posted at 2016-11-20

以前TensorFlow RNNの初学用サンプルコードに書いた(TensorFlowのRNNを基本的なモデルで試す)のですが,TensorFlowのバージョンがあがり,動かなくなっている部分があるので簡単にまとめておきます.

環境

OSX
python 3.5
TensorFlow r0.11

import

RNN関連のinportを行う際,v.0.8までは

from tensorflow.models.rnn import rnn, rnn_cell

としていましたが,repackaging関連の問題で,

ImportError: This module is deprecated.  Use tf.nn.rnn_* instead.

というエラーが出てしまいます.

そのため,rnn, rnn_cellのimportの部分は消してしまい,rnnやrnn_cellを使用している部分を

# rnn_cell.BasicLSTMCell( ->
tf.nn.rnn_cell.BasicLSTMCell(...

# rnn( ->
tf.nn.rnn(...

と変更してあげてください.

BasicLSTMCell

上記エラーを修正しても


cell = tf.nn.rnn_cell.BasicLSTMCell(
    num_of_hidden_nodes, forget_bias=forget_bias)
rnn_output, states_op = tf.nn.rnn(cell, in4, initial_state=istate_ph)

のような部分で,

TypeError: 'Tensor' object is not iterable.

というエラーが吐かれることがあります.

これは,BasicLSTMCellのデフォルト引数が変わったのが原因で,戻り値の形式としてtuple型を許すか,結合させるかを決めるstate_is_tupleがFalseからTrueに変わっているためです.

そのため,上記コードであれば


cell = tf.nn.rnn_cell.BasicLSTMCell(
    num_of_hidden_nodes, forget_bias=forget_bias, state_is_tuple=False)
rnn_output, states_op = tf.nn.rnn(cell, in4, initial_state=istate_ph)

BasicLSTMCellの引数に,state_is_tuple=Falseを指定してあげればよいです.
ただし,これに関しては公式ドキュメントにて,そのうち廃止にするよ,と書いてあるのでゆくゆくはまた書き方を変える必要がありそうです.

参考

3
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
5