如何使解码器在rnn,张量流中馈入先前的输出

时间:2018-10-02 08:54:25

标签: python tensorflow seq2seq encoder-decoder

我想知道如何在tensorflow rnn中制作解码器,将其第i个输出馈送到第(i + 1)个输入

我的输入有20个序列和3680个痴呆 我的输出有39个序列和3680个痴呆 所有数据都是0〜1个数字

这是我的模特

with tf.variable_scope('encoder'):    
    enc_input = tf.placeholder(tf.float32,[None, input_sequence_length, input_dim])

    enc_cell = tf.contrib.rnn.BasicLSTMCell(num_units = input_sequence_length)

    _ , encoder_states = tf.nn.dynamic_rnn(enc_cell, enc_input ,  dtype=tf.float32)

with tf.variable_scope('decoder'):
    dec_input = tf.placeholder(tf.float32,[None, output_sequence_length, output_dim])
    dec_output = tf.placeholder(tf.float32,[None, output_sequence_length, output_dim])

    dec_cell = tf.contrib.rnn.BasicLSTMCell(num_units = output_sequence_length)

    outputs , _ = tf.nn.dynamic_rnn(dec_cell, dec_input,dtype = tf.float32,
                                    initial_state = encoder_states)

我如何制作将先前的输出馈送到下一个输入的解码器模型?

PS

我将自己的自动应答代码设置为

with tf.variable_scope('decoder'):
    dec_input = tf.placeholder(tf.float32,[None, 1, output_dim])
    dec_output = tf.placeholder(tf.float32,[None, output_sequence_length, output_dim])

    outputs = []
    state = encoder_states

    dec_cell = tf.contrib.rnn.BasicLSTMCell(num_units = dec_hidden_size)

    for i in range(output_sequence_length):
        if i==0:
            output , state = tf.nn.dynamic_rnn(dec_cell, dec_input, initial_state = state, dtype = tf.float32)
            outputs.append(output)
        else:
            output , state = tf.nn.dynamic_rnn(dec_cell, 
                                                 output, 
                                                 initial_state = state, 
                                                 dtype = tf.float32)
            outputs.append(output)

outputs = tf.reshape(outputs,[-1,output_dim])
outputs = tf.reshape(outputs,[-1,output_sequence_length,output_dim])

我认为此代码的输出与上层代码的输出不同 但我不确定它是否正常工作。

所以我仍然想知道如何使用Tensorflow方法制作具有循环功能((i)output->(i + 1)input)的解码器, 因为它需要比上层代码更多的内存分配。 (我认为它具有相同的细胞数)

0 个答案:

没有答案