在没有老师强迫的情况下使用LSTM解码器-Tensorflow

时间:2018-08-02 16:50:40

标签: python tensorflow artificial-intelligence lstm seq2seq

我正在尝试在Tensorflow中构建序列模型的序列,我遵循了一些教程,一切都很好。直到我决定取消模型中老师的强迫为止。 以下是我正在使用的解码器网络的示例:

def decoding_layer_train(encoder_state, dec_cell, dec_embed_input, 
                     target_sequence_length, max_summary_length, 
                     output_layer, keep_prob):
"""
Create a decoding layer for training
:param encoder_state: Encoder State
:param dec_cell: Decoder RNN Cell
:param dec_embed_input: Decoder embedded input
:param target_sequence_length: The lengths of each sequence in the target batch
:param max_summary_length: The length of the longest sequence in the batch
:param output_layer: Function to apply the output layer
:param keep_prob: Dropout keep probability
:return: BasicDecoderOutput containing training logits and sample_id
"""

training_helper = tf.contrib.seq2seq.TrainingHelper(inputs=dec_embed_input,
                                                    sequence_length=target_sequence_length,
                                                    time_major=False)

training_decoder = tf.contrib.seq2seq.BasicDecoder(dec_cell, training_helper, encoder_state, output_layer)

training_decoder_output = tf.contrib.seq2seq.dynamic_decode(training_decoder,
                                                            impute_finished=True,
                                                            maximum_iterations=max_summary_length)[0]
return training_decoder_output

据我了解,TrainingHelper正在执行教师强迫。特别是它将真实的输出作为其参数的一部分。我尝试在没有培训帮助的情况下使用解码器,但这似乎是强制性的。我尝试将true输出设置为0,但显然TrainingHelper需要该输出。我也尝试过用Google搜索解决方案,但没有找到任何相关内容。

====================更新===========

很抱歉没有在前面提到它,但是我也尝试使用GreedyEmbeddingHelper。该模型运行了几次迭代就很好了,然后开始抛出运行时错误。看来GreedyEmbeddingHelper开始预测与预期形状不同的输​​出。以下是使用GreedyEmbeddingHelper

时的功能
def decoding_layer_train(encoder_state, dec_cell, dec_embeddings, 
                         target_sequence_length, max_summary_length, 
                         output_layer, keep_prob):
    """
    Create a decoding layer for training
    :param encoder_state: Encoder State
    :param dec_cell: Decoder RNN Cell
    :param dec_embed_input: Decoder embedded input
    :param target_sequence_length: The lengths of each sequence in the target batch
    :param max_summary_length: The length of the longest sequence in the batch
    :param output_layer: Function to apply the output layer
    :param keep_prob: Dropout keep probability
    :return: BasicDecoderOutput containing training logits and sample_id
    """

    start_tokens = tf.tile(tf.constant([target_vocab_to_int['<GO>']], dtype=tf.int32), [batch_size], name='start_tokens')


    training_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(dec_embeddings,
                                                                start_tokens,
                                                                target_vocab_to_int['<EOS>'])

    training_decoder = tf.contrib.seq2seq.BasicDecoder(dec_cell, training_helper, encoder_state, output_layer)

    training_decoder_output = tf.contrib.seq2seq.dynamic_decode(training_decoder,
                                                                impute_finished=True,
                                                                maximum_iterations=max_summary_length)[0]
    return training_decoder_output

这是在一次训练迭代后抛出的错误示例:

    Ok

Epoch   0 Batch    5/91 - Train Accuracy: 0.4347, Validation Accuracy: 0.3557, Loss: 2.8656
++++Epoch   0 Batch    5/91 - Train WER: 1.0000, Validation WER: 1.0000

Epoch   0 Batch   10/91 - Train Accuracy: 0.4050, Validation Accuracy: 0.3864, Loss: 2.6347
++++Epoch   0 Batch   10/91 - Train WER: 1.0000, Validation WER: 1.0000

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-115-1d2a9495ad42> in <module>()
     57                  target_sequence_length: targets_lengths,
     58                  source_sequence_length: sources_lengths,
---> 59                  keep_prob: keep_probability})
     60 
     61 

/Users/alsulaimi/Documents/AI/Tensorflow-make/workspace/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in run(self, fetches, feed_dict, options, run_metadata)
    887     try:
    888       result = self._run(None, fetches, feed_dict, options_ptr,
--> 889                          run_metadata_ptr)
    890       if run_metadata:
    891         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/Users/alsulaimi/Documents/AI/Tensorflow-make/workspace/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1116     if final_fetches or final_targets or (handle and feed_dict_tensor):
   1117       results = self._do_run(handle, final_targets, final_fetches,
-> 1118                              feed_dict_tensor, options, run_metadata)
   1119     else:
   1120       results = []

/Users/alsulaimi/Documents/AI/Tensorflow-make/workspace/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
   1313     if handle is None:
   1314       return self._do_call(_run_fn, self._session, feeds, fetches, targets,
-> 1315                            options, run_metadata)
   1316     else:
   1317       return self._do_call(_prun_fn, self._session, handle, feeds, fetches)

/Users/alsulaimi/Documents/AI/Tensorflow-make/workspace/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _do_call(self, fn, *args)
   1332         except KeyError:
   1333           pass
-> 1334       raise type(e)(node_def, op, message)
   1335 
   1336   def _extend_graph(self):

InvalidArgumentError: logits and labels must have the same first dimension, got logits shape [1100,78] and labels shape [1400]

我不确定,但是我猜GreedyEmbeddingHepler不应该用于培训。 ,感谢您的帮助和对停止老师强迫的思考。

谢谢。

1 个答案:

答案 0 :(得分:1)

有不同的助手,它们都从同一个类继承。您可以在documentation中找到更多信息。如您所说,TrainingHelper需要预定义的真实输入,这些输入应从解码器输出,并且此真实输入将作为下一步输入(而不是输入上一步的输出)。通过一些研究,这种方法应该可以加快解码器的训练速度。

在您的情况下,您正在寻找GreedyEmbeddingHelper。只需将其替换为TrainingHelper即可:

training_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
    embedding=embedding,
    start_tokens=tf.tile([GO_SYMBOL], [batch_size]),
    end_token=END_SYMBOL)

只需将其替换为在问题中使用的embedding张量和变量。该帮助程序自动获取应用嵌入的步骤的输出,并将其作为下一步的输入。第一步使用start_token

使用GreedyEmbeddingHelper生成的输出不必与预期输出的长度匹配。您必须使用填充来匹配其形状。 TensorFlow提供功能tf.pad()。另外,tf.contrib.seq2seq.dynamic_decode返回包含(final_outputs, final_state, final_sequence_lengths)的元组,因此您可以使用值final_sequece_lengths进行填充。

logits_pad = tf.pad(
    logits,
    [[0, tf.maximum(expected_length - tf.reduce_max(final_seq_lengths), 0)],
     [0, 0]],
    constant_values=PAD_VALUE,
    mode='CONSTANT')

targets_pad = tf.pad(
    targets,
    [[0, tf.maximum(tf.reduce_max(final_seq_lengths) - expected_length, 0)]],
    constant_values=PAD_VALUE,
    mode='CONSTANT')

您可能需要根据输入的形状稍微更改填充。另外,如果您设置targets参数以匹配maximum_iterations形状,则不必填充targets