Tensorflow MultiRNNCell LSTM:“'Tensor'对象不可迭代”

时间:2018-07-18 14:53:44

标签: python tensorflow deep-learning lstm recurrent-neural-network

我有以下代码:

self.inputs = tf.placeholder(shape=[None, num_inputs], dtype=tf.float32)

# Recurrent network for temporal dependencies
def make_cell(units):
    cell = tf.contrib.rnn.BasicLSTMCell(units, state_is_tuple=True)
    if mode == TRAIN and keep_prob < 1:
        cell = tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=keep_prob)
        return cell

num_units = [h_size, h_size]
multi_rnn_cell = tf.contrib.rnn.MultiRNNCell(
    [make_cell(n) for n in num_units], state_is_tuple=True)

self.state_init = multi_rnn_cell.zero_state(1, tf.float32)
h_in = tf.placeholder(tf.float32, shape=[1, h_size])
c_in = tf.placeholder(tf.float32, shape=[1, h_size])
self.state_in = (c_in, h_in)
state_in = tf.contrib.rnn.LSTMStateTuple(c_in, h_in)

rnn_in = tf.expand_dims(self.inputs, [0])
lstm_outputs, lstm_state = tf.nn.dynamic_rnn(
    cell=multi_rnn_cell, inputs=rnn_in, initial_state=state_in, dtype=tf.float32)

运行它时,出现以下错误:

  

TypeError:未启用急切执行时,张量对象不可迭代。要遍历此张量,请使用tf.map_fn。

当我运行相同的代码,但只有1个LSTMCell并且没有扩展尺寸时,它运行得很好。

我想通过添加MultiRNNCell来使用不止一层。

有人可以帮我吗?

0 个答案:

没有答案