如何使用张量流的seq2seq预测简单序列?

时间:2016-06-27 09:28:20

标签: python tensorflow lstm

我最近开始使用tensorflow,所以我仍然在努力学习基础知识。

我想创建简单的seq2seq预测。

  • 输入是0到1之间的数字列表。
  • 输出是第一个数字 列表和其余的数字乘以第一个。

我设法评估模型性能并优化权重。 我一直在努力的是如何用训练有素的模型进行预测。

 model_outputs, states = seq2seq.basic_rnn_seq2seq(encoder_inputs,
                                                  decoder_inputs,
                                                  rnn_cell.BasicLSTMCell(data_point_dim, state_is_tuple=True))

为了生成model_outputs,我需要模型的输入和输出值,这有利于评估,但在预测中我只有输入值。我猜我需要对状态做一些事情,但我不确定如何将它们转换成浮点序列。

此处提供完整代码 https://gist.github.com/anonymous/be405097927758acca158666854600a2

3 个答案:

答案 0 :(得分:4)

当你训练时,你在每个解码器的时间步长给出解码器输入作为所需的输出。 测试时,您没有所需的输出,因此您可以做的最好的事情是对输出进行采样。这将是下一个时间步的输入。

TLDR;在每个时间步输入解码器输出作为下一个时间步的输入。

编辑:某些TF代码

basic_rnn_seq2seq 函数返回 s rnn_decoder (decoder_inputs,enc_states [-1],cell)

让我们看一下 rnn_decoder : def rnn_decoder(decoder_inputs,initial_state,cell,loop_function = ,                 范围=无):   ....

loop_function :如果不是None,此函数将应用于第i个输出       为了生成第i + 1个输入,并且将忽略decoder_inputs,       除了第一个元素(“GO”符号)。这可以用于解码,       还有培训模仿http://arxiv.org/pdf/1506.03099v2.pdf

在解码过程中,您需要设置此loop_function = True

我建议查看Tensorflow seq2seq库中的translate.py文件,了解如何处理。

答案 1 :(得分:0)

user4383691的上一个答案不完整。 我有同样的问题,在深入研究rnn_decoder后,发现:模型 loop_fn 应用于第i个输出,因此 True 使没有意义,因为它不是一个功能。 您应该创建一个可以接收第i个输出并返回第i + 1个输出的函数。我仍然在制作这样的功能,并且会在完成后立即更新。

答案 2 :(得分:0)

让我们看一下source code

prev = None    for i, inp in enumerate(decoder_inputs):
     if loop_function is not None and prev is not None:
       with variable_scope.variable_scope("loop_function", reuse=True):
         inp = loop_function(prev, i)
     if i > 0:
       variable_scope.get_variable_scope().reuse_variables()
     output, state = cell(inp, state)
     outputs.append(output)
     if loop_function is not None:
       prev = output

循环枚举decoder_inputs,无论您是使用提供的decoder_input进行训练还是在没有输入的情况下进行测试。这是因为在测试时,decoder_inputs被loop_function的输出(在上面代码片段的第四行中)替换。

通常,您可以使用类似here的end_ids填充dec_inputs。

  while len(dec_inputs) < self._hps.dec_timesteps:
    dec_inputs.append(end_id)
  while len(targets) < self._hps.dec_timesteps:
    targets.append(end_id)