在Tensorflow中生成特殊输出字后如何停止RNN?

时间:2016-02-02 07:40:50

标签: tensorflow recurrent-neural-network

我想实现序列到序列学习的编码器 - 解码器模型。

编码器逐字读取输入序列并更新其隐藏状态。

解码器使用编码器的隐藏状态来初始化其隐藏状态。然后生成关于最后生成的输出(y(t-1))及其隐藏状态的输出。我想在生成特殊输出()时停止此过程。实际上,我希望能够生成不同长度的输出。我怎么能在Tensorflow中做到这一点?

1 个答案:

答案 0 :(得分:1)

我想你想要像sequence_length tf.nn.rnn这样的东西。我也想要它,但似乎TensorFlow没有它。

到目前为止我正在做的事情并找到了解决此限制的好方法是在列车时间使用EOS符号填充解码器标签。通常,您只需要其中一种,但填充其中很多都不会造成任何伤害。

在执行时,您可以手动控制每次迭代,以便在生成第一个EOS时停止,或者只运行预定义的时间步数,然后从输出中删除额外的EOS符号。解码器很快得知在第一个EOS之后,只有更多的EOS可以跟随。