Tensorflow NMT with attention: implementation error

Time: 2018-07-23 18:50:50

Tags: tensorflow rnn attention-model

I have been following TensorFlow's NMT tutorial and had the model working correctly without attention. However, when I try to add attention to the model, the loss goes haywire: for the first ~20 epochs it decreases from about 6 to 3.5, but after that it shoots all the way up to 100. I have looked at many different implementations on GitHub, but I cannot find any difference between my code and code that actually works. Any help would be greatly appreciated. My TensorFlow version is 1.5.0.

with tf.name_scope("Encoder_Vars"):
    encoder_input = tf.placeholder(shape = [None, None], dtype = tf.int32, name = 'encoder_input')
    encoder_seq_len = tf.placeholder(shape = [None], dtype = tf.int32, name = 'encoder_seq_len')
    encoder_embedding_matrix = tf.get_variable("encoder_embedding_matrix", [spa_vocab_size, embedding_dim])
    encoder_emb_input = tf.nn.embedding_lookup(encoder_embedding_matrix, encoder_input)
    tf.summary.histogram("enc_emb", encoder_embedding_matrix)

with tf.name_scope("Decoder_Vars"):
    decoder_input = tf.placeholder(shape = [None, batch_size], dtype = tf.int32, name = 'decoder_input')
    decoder_seq_len = tf.placeholder(shape = [batch_size], dtype = tf.int32, name = 'decoder_seq_len')
    decoder_embedding_matrix = tf.get_variable("decoder_embedding_matrix", [spa_vocab_size, embedding_dim])
    decoder_emb_input = tf.nn.embedding_lookup(decoder_embedding_matrix, decoder_input)
    tf.summary.histogram("dec_emb", decoder_embedding_matrix)

with tf.name_scope("Encoder"):
    encoder_lstm = tf.nn.rnn_cell.BasicLSTMCell(latent_dim, name = 'encoder_lstm')
    enc_outputs, enc_states = tf.nn.dynamic_rnn(encoder_lstm, encoder_emb_input, sequence_length = encoder_seq_len, time_major = True, dtype = tf.float32)
    tf.summary.histogram("enc_lstm_kernel", encoder_lstm.variables[0])

with tf.name_scope("Decoder_with_Attention"):
    attention_states = tf.transpose(enc_outputs, [1,0,2])
    attention_mechanism = tf.contrib.seq2seq.LuongAttention(latent_dim, attention_states, memory_sequence_length = encoder_seq_len)

    decoder_lstm = tf.contrib.seq2seq.AttentionWrapper(tf.nn.rnn_cell.BasicLSTMCell(latent_dim, name = 'decoder_lstm'), attention_mechanism, attention_layer_size = latent_dim)
    helper = tf.contrib.seq2seq.TrainingHelper(decoder_emb_input, decoder_seq_len, time_major = True)
    projection_layer = Dense(spa_vocab_size, use_bias = False, name = 'projection_layer')
    decoder = tf.contrib.seq2seq.BasicDecoder(decoder_lstm, helper, decoder_lstm.zero_state(dtype = tf.float32, batch_size = batch_size), output_layer = projection_layer)

    outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder)
    outputs = outputs.rnn_output
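
The loss and optimizer are not shown in the question. For reference, below is a minimal sketch of how the training objective is typically wired with this TF 1.x API; the decoder_target placeholder, the learning rate, the clipping norm, and the batch-major logits/targets shapes are assumptions, not taken from the original code.

# Hypothetical target placeholder, batch-major: [batch_size, max_time] (assumed, not in the question)
decoder_target = tf.placeholder(shape = [batch_size, None], dtype = tf.int32, name = 'decoder_target')

# Mask out padding positions so they do not contribute to the loss
target_weights = tf.sequence_mask(decoder_seq_len, dtype = tf.float32)

# sequence_loss expects logits [batch, time, vocab] and targets [batch, time]
loss = tf.contrib.seq2seq.sequence_loss(outputs, decoder_target, target_weights)

# Clip gradients by global norm before applying them, as in the TF NMT tutorial
params = tf.trainable_variables()
gradients = tf.gradients(loss, params)
clipped_gradients, _ = tf.clip_by_global_norm(gradients, 5.0)
train_op = tf.train.AdamOptimizer(0.001).apply_gradients(zip(clipped_gradients, params))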

0 Answers

There are no answers yet.