具有RNNCell的TensorFlow attention_decoder(state_is_tuple = True)

时间:2016-06-26 10:50:54

标签: tensorflow

我想构建一个带有attention_decoder的seq2seq模型,并使用MultiRNNCell和LSTMCell作为编码器。因为TensorFlow代码表明"这个默认行为(state_is_tuple = False)很快就会被弃用。",我为编码器设置了state_is_tuple = True。

问题在于,当我将编码器的状态传递给attention_decoder时,它报告错误:

*** AttributeError: 'LSTMStateTuple' object has no attribute 'get_shape'

这个问题似乎与seq2seq.py中的attention()函数和rnn_cell.py中的_linear()函数有关,其中代码调用了' get_shape()' LSTMStateTuple'的功能来自编码器生成的initial_state的对象。

虽然当我为编码器设置state_is_tuple = False时错误消失,但程序会发出以下警告:

WARNING:tensorflow:<tensorflow.python.ops.rnn_cell.LSTMCell object at 0x11763dc50>: Using a concatenated state is slower and will soon be deprecated.  Use state_is_tuple=True.

如果有人可以提供有关使用RNNCell(state_is_tuple = True)构建seq2seq的任何说明,我将非常感激。

1 个答案:

答案 0 :(得分:0)

我也遇到了这个问题,lstm状态需要连接,否则_linear会抱怨。 LSTMStateTuple的形状取决于您使用的细胞类型。使用LSTM单元,您可以连接这样的状态:

 query = tf.concat(1,[state[0], state[1]])

如果您正在使用MultiRNNCell,请先连接每个图层的状态:

 concat_layers = [tf.concat(1,[c,h]) for c,h in state]
 query = tf.concat(1, concat_layers)