无法使用tf.stop_gradient

时间:2019-02-28 15:53:12

标签: tensorflow

我目前正在尝试了解tf.stop_gradient的工作原理,为此,我使用了这个小代码段

tf.reset_default_graph()
w1 = tf.get_variable(name = 'w1',initializer=tf.constant(10, dtype=tf.float32))
w2 = tf.get_variable(name = 'w2',initializer=tf.constant(3,dtype=tf.float32), trainable=True)
inter = w1*w2
inter=tf.stop_gradient(inter)
loss = w1*w1 - inter  - 10
opt = tf.train.GradientDescentOptimizer(learning_rate = 0.0001)


gradients = opt.compute_gradients(loss)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(gradients))

错误:TypeError:提取参数None具有无效的类型

如果我使用tf.stop_gradient注释掉该行,则代码运行正常且符合预期。请指导我如何使用tf.stop_gradient

1 个答案:

答案 0 :(得分:1)

您正确使用了tf.stop_gradient。但是,TensorFlow通过删除所有通往inter的图形连接来停止loss处的梯度。结果,如果您使用NonedLoss/dw2来计算tf.gradients,则它将返回opt.compute_gradients,因为[1]

  

返回None可以明确表明两者之间没有图形连接。

这就是TypeError出现的方式(dLoss/dw1没有此问题)。 许多用户(包括我自己)都认为这种梯度应该是0而不是None,但是TensorFlow工程师坚持认为这是预期的行为。

幸运的是,有解决方法,请尝试以下代码:

import tensorflow as tf

w1 = tf.get_variable(name='w1', initializer=tf.constant(10, dtype=tf.float32))
w2 = tf.get_variable(name='w2', initializer=tf.constant(3, dtype=tf.float32))
inter = w1 * w2
inter = tf.stop_gradient(inter)
loss = w1*w1 - inter - 10
dL_dW = tf.gradients(loss, [w1, w2])
# Replace None gradient with 0 manully
dL_dW = [tf.constant(0) if grad is None else grad for grad in dL_dW]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(dL_dW))