如何在tf.nn.dynamic_rnn中为LSTM初始化intitial_state?

时间:2018-02-27 15:59:37

标签: python tensorflow lstm

当单元格是LSTMCell时,我不确定如何传递initial_state的值。我正在使用LSTMStateTuple,因为它显示在下面的代码中:

c_placeholder = tf.placeholder(tf.float32, [ None, config.state_dim], name='c_lstm')

h_placeholder = tf.placeholder(tf.float32, [ None, config.state_dim], name='h_lstm')

state_tuple = tf.nn.rnn_cell.LSTMStateTuple(c_placeholder, h_placeholder)

cell = tf.contrib.rnn.LSTMCell(num_units=config.state_dim, state_is_tuple=True, reuse=not is_training)  

rnn_outs, states = tf.nn.dynamic_rnn(cell=cell, inputs=x,sequence_length=seqlen, initial_state=state_tuple, dtype= tf.float32)

但是,执行会返回此错误:

TypeError: 'Tensor' object is not iterable.

以下是dynamic_rnn

文档的链接

1 个答案:

答案 0 :(得分:0)

我以前见过同样的错误。我正在使用由tf.contrib.rnn.MultiRNNCell制作的多层RNN单元格,我需要指定一个LSTMStateTuples元组 - 每层一个。像

这样的东西
state = tuple(
        [tf.nn.rnn_cell.LSTMStateTuple(c_ph[i], h_ph[i])
         for i in range(nRecurrentLayers)]
    )