使用TensorFlow RNN增量连续时间序列

时间:2016-02-17 05:51:16

标签: sequence tensorflow recurrent-neural-network

我想提供时间序列数据 - 一次一步,逐步构建rnn(经过几个初始步骤)。

目前rnn()将encoder_input和decoder_input作为完整序列。

def rnn_seq2seq(encoder_inputs, decoder_inputs, cell, initial_state=None,output_projection=None,feed_previous=False, dtype=tf.float32, scope=None):
  with tf.variable_scope(scope or "rnn_seq2seq"):
    _, enc_states = rnn.rnn(cell, encoder_inputs, initial_state=initial_state, dtype=dtype)

    def extract_argmax(prev, i):
        with tf.device('/gpu:0'):
            prev = tf.nn.softmax(tf.nn.xw_plus_b(prev, output_projection[0], output_projection[1]))
        return prev

    loop_function = None
    if feed_previous:
      loop_function = extract_argmax

    print enc_states[-1]
    return seq2seq.rnn_decoder(decoder_inputs, enc_states[-1], cell, loop_function=loop_function)

feedP = tf.placeholder(dtype=tf.bool)
with tf.variable_scope("mylstm"): 
    output,state = rnn_seq2seq(enc_inputs,dec_inputs,cell,output_projection=output_projection,feed_previous=feedP)

是否可以一次一步地提供decoder_input而不是整个序列,因为这是数据实时传输的方式?

1 个答案:

答案 0 :(得分:0)

结帐Scikit Flow。示例文件夹包含许多处理RNN的示例,并且您可以插入现有代码中的内置RNN估算器。

在估算工具中检查此fit()方法,您将找到允许持续培训的partial_fit(),以满足您的需求。许多示例使用此方法继续在while循环中进行训练并随时间保存检查点(您还可以配置频率)。

希望这有帮助。