我一直在跟着 TensorFlow 的 NMT 教程学习，并且已经正确实现了不带注意力机制的模型。但是，当我给模型加上注意力机制后，损失就变得不正常了：前 20 个 epoch 里损失会从大约 6 降到 3.5，但此后会一路飙升到 100 以上。我已经查看了 GitHub 上许多不同的实现，但没有发现我的代码与那些能正常工作的代码之间有什么区别。任何帮助都将不胜感激。我的 TensorFlow 版本是 1.5.0。
with tf.name_scope("Encoder_Vars"):
    # Source token ids, time-major: [max_time, batch] (both dims variable).
    encoder_input = tf.placeholder(shape = [None, None], dtype = tf.int32, name = 'encoder_input')
    # Per-example source lengths; masks dynamic_rnn and the attention memory.
    encoder_seq_len = tf.placeholder(shape = [None], dtype = tf.int32, name = 'encoder_seq_len')
    # NOTE(review): the encoder embedding is sized with spa_vocab_size — the same
    # vocabulary as the decoder below. For NMT the encoder usually needs the
    # *source*-language vocab size; confirm this is intentional.
    # NOTE(review): tf.get_variable ignores tf.name_scope, so this variable is
    # NOT namespaced under "Encoder_Vars"; use tf.variable_scope if that matters.
    encoder_embedding_matrix = tf.get_variable("encoder_embedding_matrix", [spa_vocab_size, embedding_dim])
    # [max_time, batch] ids -> [max_time, batch, embedding_dim] vectors.
    encoder_emb_input = tf.nn.embedding_lookup(encoder_embedding_matrix, encoder_input)
    tf.summary.histogram("enc_emb", encoder_embedding_matrix)
with tf.name_scope("Decoder_Vars"):
    # Target token ids, time-major: [max_time, batch]. The batch dimension is
    # left as None (was hard-coded to batch_size) so the same graph also accepts
    # a different batch size, e.g. at inference; feeds of shape
    # [max_time, batch_size] remain valid, so this is backward compatible and
    # consistent with the encoder placeholders above.
    decoder_input = tf.placeholder(shape = [None, None], dtype = tf.int32, name = 'decoder_input')
    # Per-example target lengths, consumed by TrainingHelper below.
    decoder_seq_len = tf.placeholder(shape = [None], dtype = tf.int32, name = 'decoder_seq_len')
    # Target (Spanish) embedding table: [vocab, embedding_dim].
    decoder_embedding_matrix = tf.get_variable("decoder_embedding_matrix", [spa_vocab_size, embedding_dim])
    # [max_time, batch] ids -> [max_time, batch, embedding_dim] vectors.
    decoder_emb_input = tf.nn.embedding_lookup(decoder_embedding_matrix, decoder_input)
    tf.summary.histogram("dec_emb", decoder_embedding_matrix)
with tf.name_scope("Encoder"):
    # Single-layer LSTM over the embedded source sequence.
    encoder_lstm = tf.nn.rnn_cell.BasicLSTMCell(latent_dim, name = 'encoder_lstm')
    # time_major=True matches the [max_time, batch, dim] embedding layout.
    # enc_outputs: [max_time, batch, latent_dim]; enc_states: final LSTMStateTuple (c, h).
    # sequence_length zeroes outputs and freezes state past each example's length.
    enc_outputs, enc_states = tf.nn.dynamic_rnn(encoder_lstm, encoder_emb_input, sequence_length = encoder_seq_len, time_major = True, dtype = tf.float32)
    # variables[0] is the cell's kernel (input+recurrent weights).
    tf.summary.histogram("enc_lstm_kernel", encoder_lstm.variables[0])
with tf.name_scope("Decoder_with_Attention"):
    # LuongAttention expects a batch-major memory [batch, max_time, units],
    # so transpose the time-major encoder outputs.
    attention_states = tf.transpose(enc_outputs, [1, 0, 2])
    attention_mechanism = tf.contrib.seq2seq.LuongAttention(
        latent_dim, attention_states, memory_sequence_length = encoder_seq_len)
    decoder_lstm = tf.contrib.seq2seq.AttentionWrapper(
        tf.nn.rnn_cell.BasicLSTMCell(latent_dim, name = 'decoder_lstm'),
        attention_mechanism, attention_layer_size = latent_dim)
    # Teacher forcing; time_major matches decoder_emb_input's layout.
    helper = tf.contrib.seq2seq.TrainingHelper(decoder_emb_input, decoder_seq_len, time_major = True)
    projection_layer = Dense(spa_vocab_size, use_bias = False, name = 'projection_layer')
    # FIX: initialise the decoder from the encoder's final state instead of all
    # zeros. AttentionWrapper's zero_state must be clone()d with the encoder
    # cell state (the pattern used in the TF NMT tutorial); passing zero_state
    # directly discards the source-sentence encoding, which destabilises
    # training with attention.
    decoder_initial_state = decoder_lstm.zero_state(
        dtype = tf.float32, batch_size = batch_size).clone(cell_state = enc_states)
    decoder = tf.contrib.seq2seq.BasicDecoder(
        decoder_lstm, helper, decoder_initial_state, output_layer = projection_layer)
    outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder)
    # Logits, time-major: [max_time, batch, spa_vocab_size].
    outputs = outputs.rnn_output