以前TensorFlow RNNの初学用サンプルコードに書いた(TensorFlowのRNNを基本的なモデルで試す)のですが,TensorFlowのバージョンがあがり,動かなくなっている部分があるので簡単にまとめておきます.
環境
OSX
python 3.5
TensorFlow r0.11
import
RNN関連のinportを行う際,v.0.8までは
from tensorflow.models.rnn import rnn, rnn_cell
としていましたが,repackaging関連の問題で,
ImportError: This module is deprecated. Use tf.nn.rnn_* instead.
というエラーが出てしまいます.
そのため,rnn, rnn_cellのimportの部分は消してしまい,rnnやrnn_cellを使用している部分を
# rnn_cell.BasicLSTMCell( ->
tf.nn.rnn_cell.BasicLSTMCell(...
# rnn( ->
tf.nn.rnn(...
と変更してあげてください.
BasicLSTMCell
上記エラーを修正しても
cell = tf.nn.rnn_cell.BasicLSTMCell(
num_of_hidden_nodes, forget_bias=forget_bias)
rnn_output, states_op = tf.nn.rnn(cell, in4, initial_state=istate_ph)
のような部分で,
TypeError: 'Tensor' object is not iterable.
というエラーが吐かれることがあります.
これは,BasicLSTMCell
のデフォルト引数が変わったのが原因で,戻り値の形式としてtuple型を許すか,結合させるかを決めるstate_is_tuple
がFalseからTrueに変わっているためです.
そのため,上記コードであれば
cell = tf.nn.rnn_cell.BasicLSTMCell(
num_of_hidden_nodes, forget_bias=forget_bias, state_is_tuple=False)
rnn_output, states_op = tf.nn.rnn(cell, in4, initial_state=istate_ph)
とBasicLSTMCell
の引数に,state_is_tuple=False
を指定してあげればよいです.
ただし,これに関しては公式ドキュメントにて,そのうち廃止にするよ,と書いてあるのでゆくゆくはまた書き方を変える必要がありそうです.
参考
- Tensor Flow - LSTM - 'Tensor' object not iterable http://stackoverflow.com/questions/40464178/tensor-flow-lstm-tensor-object-not-iterable
- 公式ドキュメント https://www.tensorflow.org/versions/r0.11/api_docs/python/rnn_cell.html#BasicLSTMCell