神经机器翻译模型预测是一个一个

时间:2018-01-15 02:55:28

标签: python tensorflow machine-learning recurrent-neural-network tensorflow-datasets

问题摘要

在以下示例中,我的NMT模型损失很大,因为它正确预测target_input而不是target_output

Targetin   :  1  3  3  3  3  6  6  6  9  7  7  7  4  4  4  4  4  9  9 10 10 10  3  3 10 10  3 10  3  3 10 10  3  9  9  4  4  4  4  4  3 10  3  3  9  9  3  6  6  6  6  6  6 10  9  9 10 10  4  4  4  4  4  4  4  4  4  4  4  4  9  9  9  9  3  3  3  6  6  6  6  6  9  9 10  3  4  4  4  4  4  4  4  4  4  4  4  4  9  9 10  3 10  9  9  3  4  4  4  4  4  4  4  4  4 10 10  4  4  4  4  4  4  4  4  4  4  9  9 10  3  6  6  6  6  3  3  3 10  3  3  3  4  4  4  4  4  4  4  4  4  4  4  4  4  9  9  3  3 10  6  6  6  6  6  3  9  9  3  3  3  3  3  3  3 10 10  3  9  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  9  3  6  6  6  6  6  6  3  5  3  3  3  3 10 10 10  3  9  9  5 10  3  3  3  3  9  9  9  5 10 10 10 10 10  4  4  4  4  3 10  6  6  6  6  6  6  3  5 10 10 10 10  3  9  9  6  6  6  6  6  6  6  6  6  9  9  9  3  3  3  6  6  6  6  6  6  6  6  3  9  9  9  3  3  6  6  6  3  3  3  3  3  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
Targetout  :  3  3  3  3  6  6  6  9  7  7  7  4  4  4  4  4  9  9 10 10 10  3  3 10 10  3 10  3  3 10 10  3  9  9  4  4  4  4  4  3 10  3  3  9  9  3  6  6  6  6  6  6 10  9  9 10 10  4  4  4  4  4  4  4  4  4  4  4  4  9  9  9  9  3  3  3  6  6  6  6  6  9  9 10  3  4  4  4  4  4  4  4  4  4  4  4  4  9  9 10  3 10  9  9  3  4  4  4  4  4  4  4  4  4 10 10  4  4  4  4  4  4  4  4  4  4  9  9 10  3  6  6  6  6  3  3  3 10  3  3  3  4  4  4  4  4  4  4  4  4  4  4  4  4  9  9  3  3 10  6  6  6  6  6  3  9  9  3  3  3  3  3  3  3 10 10  3  9  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  9  3  6  6  6  6  6  6  3  5  3  3  3  3 10 10 10  3  9  9  5 10  3  3  3  3  9  9  9  5 10 10 10 10 10  4  4  4  4  3 10  6  6  6  6  6  6  3  5 10 10 10 10  3  9  9  6  6  6  6  6  6  6  6  6  9  9  9  3  3  3  6  6  6  6  6  6  6  6  3  9  9  9  3  3  6  6  6  3  3  3  3  3  2  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
Prediction :  3  3  3  3  3  6  6  6  9  7  7  7  4  4  4  4  4  9  3  3  3  3  3  3 10  3  3 10  3  3 10  3  3  9  3  4  4  4  4  4  3 10  3  3  9  3  3  6  6  6  6  6  6 10  9  3  3  3  4  4  4  4  4  4  4  4  4  4  4  4  9  3  3  3  3  3  3  6  6  6  6  6  9  6  3  3  4  4  4  4  4  4  4  4  4  4  4  4  9  3  3  3 10  9  3  3  4  4  4  4  4  4  4  4  4  3 10  4  4  4  4  4  4  4  4  4  4  9  3  3  3  6  6  6  6  3  3  3 10  3  3  3  4  4  4  4  4  4  4  4  4  4  4  4  4  9  3  3  3 10  6  6  6  6  6  3  9  3  3  3  3  3  3  3  3  3  3  3  9  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  9  3  6  6  6  6  6  6  3  3  3  3  3  3 10  3  3  3  9  3  3 10  3  3  3  3  9  3  9  3 10  3  3  3  3  4  4  4  4  3 10  6  6  6  6  6  6  3  3 10  3  3  3  3  9  3  6  6  6  6  6  6  6  6  6  9  6  9  3  3  3  6  6  6  6  6  6  6  6  3  9  3  9  3  3  6  6  6  3  3  3  3  3  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6
Source     :  9 16  4  7 22 22 19  1 12 19 12 18  5 18  9 18  5  8 12 19 19  5  5 19 22  7 12 12  6 19  7  3 20  7  9 14  4 11 20 12  7  1 18  7  7  5 22  9 13 22 20 19  7 19  7 13  7 11 19 20  6 22 18 17 17  1 12 17 23  7 20  1 13  7 11 11 22  7 12  1 13 12  5  5 19 22  5  5 20  1  5  4 12  9  7 12  8 14 18 22 18 12 18 17 19  4 19 12 11 18  5  9  9  5 14  7 11  6  4 17 23  6  4  5 12  6  7 14  4 20  6  8 12 25  4 19  6  1  5  1  5 20  4 18 12 12  1 11 12  1 25 13 18 19  7 12  7  3  4 22  9  9 12  4  8  9 19  9 22 22 19  1 19  7  5 19  4  5 18 11 13  9  4 14 12 13 20 11 12 11  7  6  1 11 19 20  7 22 22 12 22 22  9  3  8 12 11 14 16  4 11  7 11  1  8  5  5  7 18 16 22 19  9 20  4 12 18  7 19  7  1 12 18 17 12 19  4 20  9  9  1 12  5 18 14 17 17  7  4 13 16 14 12 22 12 22 18  9 12 11  3 18  6 20  7  4 20  7  9  1  7 25 13  5 25 14 11  5 20  7 23 12  5 16 19 19 25 19  7 -1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0

很明显,预测与target_input而不是target_output匹配几乎100%,因为它应该(逐个)。使用target_output计算损失和梯度,因此奇怪的是预测与target_input匹配。

模型概述

NMT模型使用源语言中的主要单词序列预测目标语言中的单词序列。这是谷歌翻译背后的框架。由于NMT使用耦合RNN,因此它受到监督并且需要标记的目标输入和输出。

NMT使用source序列,target_input序列和target_output序列。在下面的示例中,编码器RNN(蓝色)使用源输入字来产生意义向量,其传递给解码器RNN(红色),其使用意义向量来产生输出。

当进行新的预测(推断)时,解码器RNN使用其自己的先前输出来在时间步中播种下一个预测。然而,为了改进训练,允许在每个新的时间步长处使用正确的先前预测来自我种子。这就是培训需要target_input的原因。

enter image description here

使用source,target_in,target_out 获取迭代器的代码

def get_batched_iterator(hparams, src_loc, tgt_loc):
    if not (os.path.exists('primary.csv') and os.path.exists('secondary.csv')):
        utils.integerize_raw_data()

    source_dataset = tf.data.TextLineDataset(src_loc)
    target_dataset = tf.data.TextLineDataset(tgt_loc)
    dataset = tf.data.Dataset.zip((source_dataset, target_dataset))
    dataset = dataset.shuffle(hparams.shuffle_buffer_size, seed=hparams.shuffle_seed)

    dataset = dataset.map(lambda source, target: (tf.string_to_number(tf.string_split([source], delimiter=',').values, tf.int32),
                                                  tf.string_to_number(tf.string_split([target], delimiter=',').values, tf.int32)))
    dataset = dataset.map(lambda source, target: (source, tf.concat(([hparams.sos], target), axis=0), tf.concat((target, [hparams.eos]), axis=0)))
    dataset = dataset.map(lambda source, target_in, target_out: (source, target_in, target_out, tf.size(source), tf.size(target_in)))
    # Proceed to batch and return iterator

NMT模型核心代码

def __init__(self, hparams, iterator, mode):
        source, target_in, target_out, source_lengths, target_lengths = iterator.get_next()

        # Lookup embeddings
        embedding_encoder = tf.get_variable("embedding_encoder", [hparams.src_vsize, hparams.src_emsize])
        encoder_emb_inp = tf.nn.embedding_lookup(embedding_encoder, source)
        embedding_decoder = tf.get_variable("embedding_decoder", [hparams.tgt_vsize, hparams.tgt_emsize])
        decoder_emb_inp = tf.nn.embedding_lookup(embedding_decoder, target_in)

        # Build and run Encoder LSTM
        encoder_cell = tf.nn.rnn_cell.BasicLSTMCell(hparams.num_units)
        encoder_outputs, encoder_state = tf.nn.dynamic_rnn(encoder_cell, encoder_emb_inp, sequence_length=source_lengths, dtype=tf.float32)

        # Build and run Decoder LSTM with TrainingHelper and output projection layer
        decoder_cell = tf.nn.rnn_cell.BasicLSTMCell(hparams.num_units)
        projection_layer = layers_core.Dense(hparams.tgt_vsize, use_bias=False)
        helper = tf.contrib.seq2seq.TrainingHelper(decoder_emb_inp, sequence_length=target_lengths)
        decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper, encoder_state, output_layer=projection_layer)
        outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder)
        logits = outputs.rnn_output

        if mode is 'TRAIN' or mode is 'EVAL':  # then calculate loss
            crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target_out, logits=logits)
            target_weights = tf.sequence_mask(target_lengths, maxlen=tf.shape(target_out)[1], dtype=logits.dtype)
            self.loss = tf.reduce_sum((crossent * target_weights) / hparams.batch_size)

        if mode is 'TRAIN':  # then calculate/clip gradients, then optimize model
            params = tf.trainable_variables()
            gradients = tf.gradients(self.loss, params)
            clipped_gradients, _ = tf.clip_by_global_norm(gradients, hparams.max_gradient_norm)

            optimizer = tf.train.AdamOptimizer(hparams.l_rate)
            self.update_step = optimizer.apply_gradients(zip(clipped_gradients, params))

        if mode is 'EVAL':  # then allow access to input/output tensors to printout
            self.src = source
            self.tgt_in = target_in
            self.tgt_out = target_out
            self.logits = logits

完整代码(无需解决此问题) https://github.com/nave01314/tf-nmt

1 个答案:

答案 0 :(得分:1)

用于预测具有重复结构的类似语言的语法的NMT模型的核心问题是,它被激励以简单地预测过去的预测。由于TrainingHelper通过(personId, closed)为每个步骤提供正确的先前预测以加速训练,因此人为地产生模型无法摆脱的局部最小值。

我发现的最佳选择是对损失函数进行加权,例如输出序列中输出不重复的关键点的权重更大。这将激励模型获得正确的,而不仅仅是重复 过去的预测。