Seq2Seq模型学习几次迭代后仅输出EOS令牌(<\ s>)

时间:2018-09-28 21:34:10

标签: python tensorflow lstm recurrent-neural-network seq2seq

我正在使用Cornell Movie Dialogs Corpus创建在NMT上受过训练的聊天机器人。

我的代码部分基于https://github.com/bshao001/ChatLearnerhttps://github.com/chiphuyen/stanford-tensorflow-tutorials/tree/master/assignments/chatbot

在训练过程中,我打印了从批处理中馈送到解码器的随机输出答案以及模型预测观察到的学习进度的相应答案。

我的问题:经过大约4次迭代训练,该模型学会了在每个时间步输出EOS令牌(<\s>)。即使训练继续进行,它也始终将其输出作为其响应(由logits的argmax确定)。该模型偶尔会偶尔输出一系列周期作为其答案。

我还在训练过程中打印了前10个logit值(不仅仅是argmax),以查看其中是否有正确的单词,但它似乎正在预测词汇中最常见的单词(例如,i, ?,)。在培训期间,即使是前10个字词也没有太大变化。

我确保正确计算编码器和解码器的输入序列长度,并相应地添加了SOS(<s>)和EOS(也用于填充)令牌。我还在损失计算中执行掩蔽

以下是示例输出:

培训迭代1:

Decoder Input: <s> sure . sure . <\s> <\s> <\s> <\s> <\s> <\s> <\s> 
<\s> <\s>
Predicted Answer: wildlife bakery mentality mentality administration 
administration winston winston winston magazines magazines magazines 
magazines

...

训练迭代4:

Decoder Input: <s> i guess i had it coming . let us call it settled . 
<\s> <\s> <\s> <\s> <\s>
Predicted Answer: <\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s> 
<\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s>


再经过几次迭代后,它只依靠预测EOS(很少出现周期)

我不确定是什么原因导致了此问题,并且已经将其停留了一段时间。任何帮助将不胜感激!

更新:我让它进行了十万次迭代训练,但它仍然仅输出EOS(偶尔出现)。经过几次迭代后,训练损失也不会减少(从一开始就保持在47左右)

2 个答案:

答案 0 :(得分:0)

最近,我也在seq2seq模型上工作。 在遇到我的情况之前,我是通过更改损失函数来解决问题的。

您说过使用口罩,所以我猜您像以前一样使用tf.contrib.seq2seq.sequence_loss

我改为tf.nn.softmax_cross_entropy_with_logits,它可以正常工作(并且计算成本更高)。

(编辑05/10/2018。对不起,我发现我的代码存在严重错误,因此我需要进行编辑)

如果tf.contrib.seq2seq.sequence_losslogitstargets的形状正确,

mask可以很好地工作。 根据官方文件中的定义: tf.contrib.seq2seq.sequence_loss

loss=tf.contrib.seq2seq.sequence_loss(logits=decoder_logits,
                                      targets=decoder_targets,
                                      weights=masks) 

#logits:  [batch_size, sequence_length, num_decoder_symbols]  
#targets: [batch_size, sequence_length] 
#weights: [batch_size, sequence_length] 

嗯,即使形状不符合,它仍然可以工作。但是结果可能很奇怪(很多#EOS #PAD ...等)。

由于decoder_outputsdecoder_targets的形状可能与要求的形状相同(在我的情况下,我的decoder_targets的形状为[sequence_length, batch_size])。 因此,尝试使用tf.transpose帮助您重塑张量。

答案 1 :(得分:0)

在我的例子中,这是由于优化器,我错误地设置了一个大的 lr_decay,使其不再正常工作。

检查 Lr 和优化器/调度器可能会有所帮助。