環境
Windows 10
Python 3.6.2
Tensorflow 1.5.0
エラー内容
TensorflowにてRNNを構築して、実行時に下記のエラーが出ました。
- コード
error.py
n_hidden = 20
n_batch = tf.placeholder(tf.int32)
cell = tf.nn.rnn_cell.BasicRNNCell(n_hidden)
initial_state = cell.zero_state(n_batch, tf.float32)
- エラー
ValueError: prefix tensor must be either a scalar or vector, but saw tensor: Tensor("Placeholder_2:0", dtype=int32)
対処
zero_state
に渡しているn_batch
のshapeが未指定ということらしいです。
以下のように対処します。
n_batch = tf.placeholder(tf.int32, shape=[])
参考
BasicLSTMCell zero_state() raise error in TF1.2 but works in TF1.1