Tensorflow:为什么不进行图形成本计算

时间:2016-10-04 21:53:02

标签: python tensorflow

我有一个标准的实验循环,如下所示:

cross_entropy_target = tf.reduce_mean(tf.reduce_mean(tf.square(target_pred - target)))
cost = cross_entropy_target
opt_target = tf.train.AdamOptimizer(learning_rate=0.00001).minimize(cost)
for epoch in range(num_epochs):
    for mini_batch in range(num_samples / batch_size):
        mb_train_x, mb_train_target = get_mini_batch_stuffs()
        sess.run(opt_target, feed_dict={x: mb_train_x, target: mb_train_target})

这会运行并收敛到良好的预测损失。现在,稍微修改相同的代码:

cross_entropy_target = tf.reduce_mean(tf.reduce_mean(tf.square(target_pred - target)))
cross_entropy_target_variable = tf.Variable(0.0)
cost = cross_entropy_target_variable
opt_target = tf.train.AdamOptimizer(learning_rate=0.00001).minimize(cost)
for epoch in range(num_epochs):
    for mini_batch in range(num_samples / batch_size):
        mb_train_x, mb_train_target = get_mini_batch_stuffs()
        new_target_cost = sess.run(cross_entropy_target, feed_dict={x: mb_train_x, time: mb_train_time, target: mb_train_target})
        sess.run(tf.assign(cross_entropy_target_variable, new_target_cost))
        sess.run(opt_target, feed_dict={x: mb_train_x, target: mb_train_target})

现在,我不是将cross_entropy_target计算为opt_target图的一部分,而是预先计算它,将其分配给tensorflow变量,并期望它使用该值。这根本不起作用。网络的输出永远不会改变。

我希望这两个代码片段具有相同的结果。在这两种情况下,都会使用前馈来填充targettarget_pred的值,然后将其缩减为标量值cross_entropy_target。此标量值用于通知优化程序.minimize()上梯度更新的大小和方向。

在这个玩具示例中,我计算cross_entropy_target“out of graph”并将其分配给图表tf.Variable以便在opt_target运行中使用没有任何好处。但是,我有一个真实的用例,我的成本函数非常复杂,我无法根据Tensorflow现有的张量变换来定义它。无论哪种方式,我想了解为什么使用tf.Variable来优化器的成本是错误的使用。

有趣的奇怪可能是解决方案的副产品: 如果我设置cross_entropy_target_variable = tf.Variable(0.0, trainable=False),则运行opt_target会崩溃。它要求成本价值可以修改。实际上,在运行opt_target之前和之后打印出它的值会产生不同的值:

cross_entropy_target before = 0.345796853304 cross_entropy_target after = 0.344796866179

为什么运行minimize()会修改成本变量的值?

1 个答案:

答案 0 :(得分:1)

tf.train.AdamOptimizer(行中,它会查看costcross_entropy_target,这是tf.Variable操作,并创建一个不执行任何操作的优化程序,因为{{ 1}}并不依赖于任何变量。稍后修改cross_entropy_target目标无效,因为已经创建了优化程序。