“ tf.contrib.rnn.OutputProjectionWrapper”在“ tf.contrib.rnn.BasicRNNCell”中起什么作用

时间:2019-05-03 06:57:55

标签: tensorflow lstm recurrent-neural-network

在许多实现BasicRNNCell的地方,发现使用以下代码:

tf.contrib.rnn.OutputProjectionWrapper(
    tf.contrib.rnn.BasicRNNCell(num_units= num_neurons , activation=tf.nn.relu), 
    output_size=num_outputs)

“ OutputProjectionWrapper”在“ BasicRNNCell”上做什么?

根据为实现“ tf.contrib.rnn.BasicRNNCell”而看到的代码,调用函数将返回RNN的输出。我们可以直接使用其调用函数进行操作。

# Creating the Model

num_inputs = 1
num_neurons = 100
num_outputs = 1
learning_rate = 0.005
num_train_iterations = 2000
batch_size = 1


tf.reset_default_graph()
x = tf.placeholder(tf.float32, [None, num_time_steps, num_inputs])
y = tf.placeholder(tf.float32, [None, num_time_steps, num_outputs])

# Using Basic RNN Model

cell= tf.contrib.rnn.OutputProjectionWrapper(tf.contrib.rnn.BasicRNNCell(num_units=num_neurons,activation=tf.nn.relu),output_size=num_outputs)

outputs, states = tf.nn.dynamic_rnn(cell, x, dtype=tf.float32)

# MEAN SQUARED ERROR
loss = tf.reduce_mean(tf.square(outputs - y))
optimizer = tf.train.AdagradOptimizer(learning_rate = learning_rate)
train = optimizer.minimize(loss)

我希望我们可以将BasicRNNCell直接传递给tf.nn.dynamic_rnn,但是在执行此步骤之前,OutputProjectionWrapper对我来说是完全未知的。

1 个答案:

答案 0 :(得分:0)

是的,您可以直接将BasicRNNCell传递到tf.nn.dynamic_rnn,也可以在将投影层过渡到tf.nn.dynamic_rnn之前向其添加投影层。 OutputProjectionWrapper的作用是在RNN的输出之后添加一个密集层。