在tensorflow代码示例中

时间:2017-08-03 23:53:30

标签: python tensorflow lstm recurrent-neural-network rnn

为什么在任何训练迭代发生之前计算pred变量?我希望在每次迭代的每次数据传递过程中,都会生成一个pred(通过RNN()函数)?

一定有我遗失的东西。 pred类似于函数对象吗?我查看了tf.matmul()的文档,它返回张量,而不是函数。

完整来源:https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/3_NeuralNetworks/recurrent_network.py

以下是代码:

def RNN(x, weights, biases):

    # Prepare data shape to match `rnn` function requirements
    # Current data input shape: (batch_size, n_steps, n_input)
    # Required shape: 'n_steps' tensors list of shape (batch_size, n_input)

    # Unstack to get a list of 'n_steps' tensors of shape (batch_size, n_input)
    x = tf.unstack(x, n_steps, 1)

    # Define a lstm cell with tensorflow
    lstm_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)

    # Get lstm cell output
    outputs, states = rnn.static_rnn(lstm_cell, x, dtype=tf.float32)

    # Linear activation, using rnn inner loop last output
    return tf.matmul(outputs[-1], weights['out']) + biases['out']

pred = RNN(x, weights, biases)

# Define loss and optimizer
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)

# Evaluate model
correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Initializing the variables
init = tf.global_variables_initializer()

1 个答案:

答案 0 :(得分:0)

Tensorflow代码有两个不同的阶段。首先,构建一个"依赖图",其中包含您将使用的所有操作。请注意,在此阶段您不处理任何数据。相反,您只是定义要发生的操作。 Tensorflow注意到操作之间的依赖关系。

例如,为了计算accuracy,您需要先计算correct_pred,然后计算correct_pred,您需要先计算pred {1}},等等。

所以你在所显示的代码中所做的就是告诉tensorflow你想要什么操作。您已将这些保存在"图表中。数据结构(这是一个张量流数据结构,基本上是一个包含所有数学运算和张量的桶)。

稍后,您将使用对sess.run([ops], feed_dict={inputs})的调用对数据执行操作。

当您致电sess.run时,请注意您必须从图表中告诉您想要的内容。如果您要求accuracy

   sess.run(accuracy, feed_dict={inputs})

Tensorflow将尝试计算准确性。它会看到accuracy依赖于correct_pred,因此它会尝试通过您定义的依赖关系图来计算它,等等。

您所犯的错误是您认为列出的代码中的pred正在计算某些内容。不是。这一行:

   pred = RNN(x, weights, biases)

仅定义了操作及其依赖项。