TensorFlow LSTM batch training

Date: 2017-12-24 04:38:43

Tags: python tensorflow recurrent-neural-network

I coded an LSTM RNN where x has the shape [n_batch = 25, seq_len = 250, n_inputs = 1]. I train the model on segments of a sine function and predict another segment. Below is the result; for plotting, I reshape the predictions from [n_batch = 25, seq_len = 250, n_output = 1] to [test_len = 2000, n_output = 1].

[LSTM prediction on sin function]
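
For reference, here is a minimal sketch of how batches of that shape can be sliced from a sine wave (the exact windowing below is an assumption, not necessarily the preprocessing used in my script):

    import numpy as np

    n_batch, seq_len, n_inputs = 25, 250, 1

    # assumption: consecutive, non-overlapping windows of one long sine wave,
    # with the target being the input shifted forward by one step
    t = np.linspace(0, 100 * np.pi, n_batch * seq_len + 1)
    wave = np.sin(t)

    train_x = wave[:-1].reshape(n_batch, seq_len, n_inputs)  # [25, 250, 1]
    train_y = wave[1:].reshape(n_batch, seq_len, n_inputs)   # [25, 250, 1]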

In the plot you can see that the prediction drops to 0 at the beginning of every batch. My code is below:

    with tf.Session() as sess:
        sess.run(init_op)
        training_mse = []
        # evaluate the all-zero starting state once so it can be fed back in
        zero_state = sess.run(initial_state)

        for epoch in range(100):
            # every epoch restarts from the zero state
            current_state = zero_state
            feed_dict = {
                tf_x: train_x,
                tf_y: train_y,
                initial_state: current_state,
            }

            # one optimization step over the full batch
            sess.run(self.optimize, feed_dict=feed_dict)

            # read back the error and the final recurrent state
            cost, current_state = sess.run([error, final_state], feed_dict=feed_dict)

    def forward(self):
        tf_x = tf.placeholder(tf.float32, [batch_size, seq_len, n_inputs])
        tf_y = tf.placeholder(tf.float32, [batch_size, seq_len, n_outputs])

        keep_prob = tf.placeholder(tf.float32)

        # stack of n_layers LSTM cells with 10 units each
        layers = [
            tf.nn.rnn_cell.LSTMCell(10)
            for _ in range(n_layers)
        ]
        cells = tf.nn.rnn_cell.MultiRNNCell(layers)

        # all-zero state for the whole stack, shaped for batch_size sequences
        initial_state = cells.zero_state(batch_size, tf.float32)

        outputs, final_state = tf.nn.dynamic_rnn(
            cells,
            tf_x,
            initial_state=initial_state,
            dtype=tf.float32,
        )

        # flatten batch and time, then project to the output dimension
        outputs = tf.reshape(outputs, [-1, self.params.n_neurons])
        predictions = tf.matmul(outputs, W2) + b2
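
W2 and b2 are defined elsewhere in the class; presumably they are the output-projection variables, along these lines (an assumption, where self.params.n_neurons must match the LSTMCell size of 10 for the matmul to work):

    # assumption: projection from the LSTM's hidden size to n_outputs
    W2 = tf.Variable(tf.truncated_normal([n_neurons, n_outputs], stddev=0.1))
    b2 = tf.Variable(tf.zeros([n_outputs]))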

I think this is due to the initial_state of the LSTM cells, but I am not sure why it behaves this way. Any ideas?
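
To illustrate the suspicion: every row of the batch starts from the zero state, so if the 25 sequences are consecutive slices of one signal, the final state of each run would need to be fed back as the initial state of the next. A rough, untested sketch of that pattern (here `chunks_x` and `chunks_y` are hypothetical lists of consecutive windows, and the placeholders are assumed to be built with a matching batch size):

    # sketch: carry the recurrent state across consecutive chunks instead of
    # restarting from the zero state (chunks_x / chunks_y are hypothetical)
    current_state = sess.run(initial_state)
    for chunk_x, chunk_y in zip(chunks_x, chunks_y):
        feed_dict = {tf_x: chunk_x, tf_y: chunk_y, initial_state: current_state}
        _, current_state = sess.run([self.optimize, final_state], feed_dict=feed_dict)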

0 Answers:

No answers yet