为什么tf.stop_gradient不能冻结变量?

时间:2017-09-11 14:43:38

标签: tensorflow

查看代码:

import tensorflow as tf

x = tf.Variable(1)
x = tf.stop_gradient(x)
y = 2 * x
g = tf.gradients(y, x)
init = tf.global_variables_initializer()
with tf.Session() as sess:
    init.run()
    print(sess.run(g))

我想要冻结xy wrt x的渐变为零,但输出为2,那有什么不对?

更新

import tensorflow as tf

x0 = tf.Variable(1)
x1 = tf.stop_gradient(x0)
y = 2 * x1
g = tf.gradients(y, x0)
init = tf.global_variables_initializer()
with tf.Session() as sess:
    init.run()
    print(sess.run(g))

我使用的x1不会覆盖x0,但是当我运行此代码时,会引发错误。如果tf.stop_gradient充当tf.identity,我认为x0将在计算图中显示y的路径,渐变为0而不是引发错误?有人可以告诉我tf.stop_gradient确实做了什么吗?

1 个答案:

答案 0 :(得分:-1)

tf.stop_gradient()停止计算图中指定点的梯度计算。将它应用于变量是太迟了#34;但您可以将其应用于y

import tensorflow as tf

x = tf.Variable(1)
y = tf.stop_gradient(2 * x)
g = tf.gradients(y, x)
init = tf.global_variables_initializer()
with tf.Session() as sess:
    init.run()
    print(sess.run(g))

请注意,这不会输出0,而是会抛出错误,因为y w.r.t的渐变在这种情况下,x未定义,您明确要求它。在现实情况下,情况可能并非如此。