TensorFlow LSTM状态从元组切换到Tensor

时间:2017-06-12 21:17:37

标签: python-3.x tensorflow

我正在将我的评论从https://github.com/tensorflow/tensorflow/issues/8833移到StackOverflow,因为这似乎更合适。

我正在尝试使用tensorflow.contrib.seq2seqtensorflow.contrib.rnn的{​​{1}}来实现序列模型。在BasicLSTMCell内,行rnn_cell_impl.py会导致以下错误:

c, h = state

当单步执行代码时,我了解到错误是在第三次TypeError: 'Tensor' object is not iterable.被评估时引起的。前两次,state的类型为c, h = state,但第三次,state的类型为<class 'tensorflow.python.ops.rnn_cell_impl.LSTMStateTuple'>。显然,我希望第三次使用LSTMStateTuple类型,但我不知道是什么原因导致切换。

有问题的<class 'tensorflow.python.framework.ops.Tensor'>张量名称为state。我编写了方法define_model/define_decoder/decoder/while/Identity_3define_model(),剩下的信息表明我的define_decoder()内发生了一些事情。

如果它是相关的,我使用的是Python 3.6和Tensorflow 1.2。

2 个答案:

答案 0 :(得分:1)

答案可以在上面linked Github issue page找到。

简而言之,问题是我的编码器使用双向RNN,它产生2元组的LSTMStateTuples,即每个定向RNN的一个c和一个h状态。然后,稍后,解码器接受单个单元,其与单个LSTMStateTuple相关联。要解决此问题,您需要单独连接双向RNNS的c状态和h状态,将其包装为新的LSTMStateTuple并将其传递给解码器的状态。

答案 1 :(得分:0)

我认为可以找到类似的答案here

代码将cudnn cell state转换为tensorflow内部状态。

参见此方法

def cudnn_lstm_state_to_state_tuples(cudnn_lstm_state):