Making an RNN model without tf.nn.dynamic_rnn, tf.nn.basic_cell, ProjectionWithoutWrapper, but the model doesn't work

Time: 2019-03-14 03:43:43

Tags: tensorflow recurrent-neural-network

Here is my RNN code. The model predicts future values.

For example, if the input is

x = [x1, x2, x3, x4],

the model's output should be

pred = [x2, x3, x4, x5] 
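The post does not show how its training data is generated. As a minimal sketch, assuming a 1-D series held in NumPy, input/target pairs of this kind can be built by taking each target window as the input window shifted one step ahead (the helper name `make_shifted_batch` is hypothetical):

```python
import numpy as np


def make_shifted_batch(series, n_steps):
    """Hypothetical helper: build (x, y) windows where y is x shifted one step ahead."""
    series = np.asarray(series, dtype=np.float32)
    n_windows = len(series) - n_steps  # each window needs n_steps + 1 points
    xs = np.stack([series[i:i + n_steps] for i in range(n_windows)])
    ys = np.stack([series[i + 1:i + n_steps + 1] for i in range(n_windows)])
    return xs, ys


xs, ys = make_shifted_batch([1, 2, 3, 4, 5, 6], n_steps=4)
# xs == [[1, 2, 3, 4], [2, 3, 4, 5]]
# ys == [[2, 3, 4, 5], [3, 4, 5, 6]]
```

Both arrays have shape [n_windows, n_steps]. For placeholders of shape [None, n_steps, n_inputs] with n_inputs == 1, a trailing axis would be needed, e.g. `xs[..., np.newaxis]`.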

When I use the high-level API (e.g. the wrappers), it works.
But when I don't use the high-level API, it doesn't work,
and I don't know why...
Could you please check my code?

The full code is linked.

Thanks in advance :)

I think the following part of the code is the important one.

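import tensorflow as tf  # TensorFlow 1.x API

# n_steps, n_inputs, n_neurons and n_outputs are defined in the full code (not shown here)
x = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
y = tf.placeholder(tf.float32, [None, n_steps, n_inputs])

# Built-in cell unrolled by dynamic_rnn; outputs are batch-major: [batch, n_steps, n_neurons]
cell = tf.nn.rnn_cell.BasicRNNCell(num_units=n_neurons, activation=tf.nn.relu)
outputs, states = tf.nn.dynamic_rnn(cell, x, dtype=tf.float32)
# Project every time step to n_outputs with one shared dense layer
stacked_outputs = tf.reshape(outputs, shape=[-1, n_neurons])
stacked_logits = tf.layers.dense(stacked_outputs, n_outputs)
outputs = tf.reshape(stacked_logits, [-1, n_steps, n_outputs])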
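The reshape / dense / reshape pattern above applies one shared projection to every time step. It only lines up correctly because tf.nn.dynamic_rnn returns batch-major outputs by default (time_major=False), so flattening and un-flattening keep each sample's steps together. A NumPy sketch of that equivalence, with arbitrary illustrative sizes and plain matrices `W` and `b` standing in (as an assumption) for the tf.layers.dense kernel and bias:

```python
import numpy as np

rng = np.random.default_rng(0)
batch, n_steps, n_neurons, n_outputs = 3, 4, 5, 1  # illustrative sizes only

# Batch-major RNN outputs, like dynamic_rnn returns by default.
outputs = rng.standard_normal((batch, n_steps, n_neurons))
W = rng.standard_normal((n_neurons, n_outputs))  # stand-in for the dense kernel
b = rng.standard_normal(n_outputs)               # stand-in for the dense bias

# Flatten time into the batch axis, apply one projection, restore the shape.
stacked = outputs.reshape(-1, n_neurons)
projected = (stacked @ W + b).reshape(-1, n_steps, n_outputs)

# The same projection done explicitly, one time step at a time.
per_step = np.stack([outputs[:, t, :] @ W + b for t in range(n_steps)], axis=1)

print(np.allclose(projected, per_step))  # True
```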

I changed this code as follows:

import tensorflow as tf  # TensorFlow 1.x API

# Model
# x holds one value per time step, so the reshape below assumes n_inputs == 1
x = tf.placeholder(tf.float32, [None, n_steps])
x_res = tf.reshape(x, shape=[-1, n_steps, n_inputs])
# Make the inputs time-major, then split them into n_steps tensors of shape [batch, n_inputs]
x_trpose = tf.transpose(x_res, perm=(1, 0, 2))
xs_seq = tf.unstack(x_trpose)

y = tf.placeholder(tf.float32, [None, n_steps])
y_res = tf.reshape(y, shape=[-1, n_steps, n_inputs])

# Cell parameters (tf.random_normal defaults to mean 0.0, stddev 1.0)
init_wx = tf.random_normal(shape=[n_inputs, n_neurons], dtype=tf.float32)
wx = tf.Variable(init_wx)
init_wh = tf.random_normal(shape=[n_neurons, n_neurons], dtype=tf.float32)
wh = tf.Variable(init_wh)
init_b = tf.constant(value=0, shape=[n_neurons], dtype=tf.float32)
b = tf.Variable(init_b)

# Initial hidden state: init_hidden has to be fed (e.g. with zeros); zeros_like only uses its shape
init_hidden = tf.placeholder(shape=[None, n_neurons], dtype=tf.float32)
hidden_state = tf.zeros_like(init_hidden, dtype=tf.float32)

activation = tf.nn.relu
output_layers = []
logits_list = []

# Unroll the recurrence by hand: h_t = relu(h_{t-1} Wh + x_t Wx + b)
for i, x_seq in enumerate(xs_seq):
    hidden_layer = tf.matmul(hidden_state, wh)
    input_layer = tf.matmul(x_seq, wx)
    output_layer = activation(hidden_layer + input_layer + b)
    output_layers.append(output_layer)
    hidden_state = output_layer

# output_layers is a Python list of n_steps tensors, each of shape [batch, n_neurons]
stacked_outputs = tf.reshape(output_layers, shape=[-1, n_neurons])
stacked_logits = tf.layers.dense(stacked_outputs, n_outputs)
outputs = tf.reshape(stacked_logits, [-1, n_steps, n_outputs])
print(outputs)
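One difference from the dynamic_rnn version that may explain the failure: output_layers is a Python list, and tf.reshape converts such a list to a tensor by packing it along axis 0, which gives a time-major [n_steps, batch, n_neurons] layout. Reshaping that to [-1, n_steps, n_outputs] then silently mixes samples whenever batch > 1. A small NumPy reproduction of the two layouts, with tiny hypothetical sizes and values encoded as 10 * sample + step so each number's origin is readable:

```python
import numpy as np

batch, n_steps, n_neurons = 2, 3, 1  # tiny illustrative sizes

# Hypothetical per-step outputs: a list of n_steps arrays of shape [batch, n_neurons].
output_layers = [np.array([[10 * s + t] for s in range(batch)], dtype=float)
                 for t in range(n_steps)]

# What passing the list to tf.reshape effectively does: pack along axis 0 (time-major).
time_major = np.stack(output_layers, axis=0)  # [n_steps, batch, n_neurons]
wrong = time_major.reshape(-1, n_neurons).reshape(-1, n_steps, n_neurons)

# Packing along axis 1 gives the batch-major layout dynamic_rnn uses.
batch_major = np.stack(output_layers, axis=1)  # [batch, n_steps, n_neurons]
right = batch_major.reshape(-1, n_neurons).reshape(-1, n_steps, n_neurons)

print(wrong[..., 0].tolist())  # [[0.0, 10.0, 1.0], [11.0, 2.0, 12.0]]: samples interleaved
print(right[..., 0].tolist())  # [[0.0, 1.0, 2.0], [10.0, 11.0, 12.0]]: one sample per row
```

If this is the cause, stacking with `tf.stack(output_layers, axis=1)` before the first reshape should reproduce the batch-major layout of the dynamic_rnn version. Separately, the recurrent matrix is drawn with tf.random_normal at stddev 1.0, which may let relu activations grow quickly over the steps; that is worth checking on its own.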
