查看代码:
import tensorflow as tf
x = tf.Variable(1)
x = tf.stop_gradient(x)
y = 2 * x
g = tf.gradients(y, x)
init = tf.global_variables_initializer()
with tf.Session() as sess:
init.run()
print(sess.run(g))
我想要冻结x
,y
wrt x
的渐变为零,但输出为2
,那有什么不对?
import tensorflow as tf
x0 = tf.Variable(1)
x1 = tf.stop_gradient(x0)
y = 2 * x1
g = tf.gradients(y, x0)
init = tf.global_variables_initializer()
with tf.Session() as sess:
init.run()
print(sess.run(g))
我使用的x1
不会覆盖x0
,但是当我运行此代码时,会引发错误。如果tf.stop_gradient
充当tf.identity
,我认为x0
将在计算图中显示y
的路径,渐变为0
而不是引发错误?有人可以告诉我tf.stop_gradient
确实做了什么吗?
答案 0 :(得分:-1)
tf.stop_gradient()
停止计算图中指定点的梯度计算。将它应用于变量是太迟了#34;但您可以将其应用于y
。
import tensorflow as tf
x = tf.Variable(1)
y = tf.stop_gradient(2 * x)
g = tf.gradients(y, x)
init = tf.global_variables_initializer()
with tf.Session() as sess:
init.run()
print(sess.run(g))
请注意,这不会输出0
,而是会抛出错误,因为y w.r.t的渐变在这种情况下,x未定义,您明确要求它。在现实情况下,情况可能并非如此。