I have the following code:
self.inputs = tf.placeholder(shape=[None, num_inputs], dtype=tf.float32)

# Recurrent network for temporal dependencies
def make_cell(units):
    cell = tf.contrib.rnn.BasicLSTMCell(units, state_is_tuple=True)
    if mode == TRAIN and keep_prob < 1:
        cell = tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=keep_prob)
    return cell

num_units = [h_size, h_size]
multi_rnn_cell = tf.contrib.rnn.MultiRNNCell(
    [make_cell(n) for n in num_units], state_is_tuple=True)

self.state_init = multi_rnn_cell.zero_state(1, tf.float32)
h_in = tf.placeholder(tf.float32, shape=[1, h_size])
c_in = tf.placeholder(tf.float32, shape=[1, h_size])
self.state_in = (c_in, h_in)
state_in = tf.contrib.rnn.LSTMStateTuple(c_in, h_in)

rnn_in = tf.expand_dims(self.inputs, [0])
lstm_outputs, lstm_state = tf.nn.dynamic_rnn(
    cell=multi_rnn_cell, inputs=rnn_in, initial_state=state_in, dtype=tf.float32)
When I run it, I get the following error:

TypeError: Tensor objects are not iterable when eager execution is not enabled. To iterate over this tensor use tf.map_fn.
When I run the same code with just a single LSTMCell (and without the expand_dims), it works fine.
I want to use more than one layer, which is why I added the MultiRNNCell.
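If I understand the API correctly, multi_rnn_cell.zero_state returns one state tuple per layer, whereas I pass dynamic_rnn a single LSTMStateTuple. Here is a plain-Python sketch (no TensorFlow, the names are just illustrative) of the two structures as I understand them:

```python
# Mimic tf.contrib.rnn.LSTMStateTuple with a plain namedtuple.
from collections import namedtuple

LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])

# What my code feeds to dynamic_rnn: a single (c, h) pair.
single_state = LSTMStateTuple(c="c_in", h="h_in")

# What a two-layer MultiRNNCell's zero_state looks like:
# a tuple holding one LSTMStateTuple per layer.
multi_state = (
    LSTMStateTuple(c="c_in_layer0", h="h_in_layer0"),
    LSTMStateTuple(c="c_in_layer1", h="h_in_layer1"),
)

print(len(multi_state))  # 2 -- one state tuple per layer
```

So I suspect the shapes of the initial state don't match, but I'm not sure how to build the right structure for initial_state.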
Can someone help me?