如何提高张量流中的rnn精度?

时间:2016-12-20 05:23:47

标签: tensorflow deep-learning lstm

我正在使用张量流seq2seq.rnn_decoder进行标题自动生成项目。

我的训练集是一大堆标题,每个标题彼此独立且不相关。

我尝试过两种数据格式进行培训:

F1. Use the fixed seq length in batch, and replace ‘\n’ to ‘<eos>’, and ‘<eos>’ index is 1, which training batch is like: [2,3,4,5,8,9,1,2,3,4], [88,99,11,90,1,5,6,7,8,10]
F2. Use Variable seq length in batch, and add PAD 0 to keep the fixed length, which training batch is like: [2,3,4,5,8,9,0,0,0,0], [2,3,4,88,99,90,11,0,0,0]

然后我在一个有10,000个标题的小集中进行测试,但结果让我感到困惑。

F1在单个单词中做出了很好的预测,如下所示:

iphone predict 6
samsung predict galaxy
case predict cover

如果输入是从句子的第一个单词开始,则F2在长句中做出很好的预测,多次预测几乎等于原始句子。

但是,如果起始单词来自句子的中间(或接近结尾),则F2的预测非常非常糟糕,就像随机结果一样。

这种情况是否与隐藏状态有关?

在训练阶段,我在新纪元开始时将隐藏状态重置为0,所以在纪元中所有批次都将使用相同的隐藏状态,我怀疑这不是一个好习惯,因为每个句子实际上都是独立的,它是否可以在训练中共享相同的隐藏状态?

在推断阶段,初始隐藏状态为0,&amp;在提供单词时更新。 (清除输入时复位为0)

所以我的问题是,为什么当开始单词是从句子的中间(或接近结尾)开始时,F2的预测是不好的?在项目中更新隐藏状态的正确方法是什么?

1 个答案:

答案 0 :(得分:0)

我不确定我是否正确理解了您的设置,但我认为您所看到的情况是预期的,并且与处理隐藏状态有关。

让我们先看看你在F2中看到的内容。由于您每次都重置隐藏状态,因此网络只在整个标题的开头看到0状态,对吧?因此,在训练期间,除了启动序列之外,它可能永远不会有0状态。当你尝试从中间解码时,你从0状态开始,在训练期间从未见过这样的位置,所以它失败了。

在F1中,你也重置状态,但由于你没有填充,0状态在训练期间更随机 - 有时在开始时,有时在标题的中间。网络学会了应对这一点。