我目前正在尝试了解tf.stop_gradient的工作原理,为此,我使用了这个小代码段
tf.reset_default_graph()
w1 = tf.get_variable(name = 'w1',initializer=tf.constant(10, dtype=tf.float32))
w2 = tf.get_variable(name = 'w2',initializer=tf.constant(3,dtype=tf.float32), trainable=True)
inter = w1*w2
inter=tf.stop_gradient(inter)
loss = w1*w1 - inter - 10
opt = tf.train.GradientDescentOptimizer(learning_rate = 0.0001)
gradients = opt.compute_gradients(loss)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(gradients))
错误:TypeError:提取参数None具有无效的类型
如果我使用tf.stop_gradient注释掉该行,则代码运行正常且符合预期。请指导我如何使用tf.stop_gradient
答案 0 :(得分:1)
您正确使用了tf.stop_gradient
。但是,TensorFlow通过删除所有通往inter
的图形连接来停止loss
处的梯度。结果,如果您使用None
或dLoss/dw2
来计算tf.gradients
,则它将返回opt.compute_gradients
,因为[1]
返回
None
可以明确表明两者之间没有图形连接。
这就是TypeError
出现的方式(dLoss/dw1
没有此问题)。
许多用户(包括我自己)都认为这种梯度应该是0
而不是None
,但是TensorFlow工程师坚持认为这是预期的行为。
幸运的是,有解决方法,请尝试以下代码:
import tensorflow as tf
w1 = tf.get_variable(name='w1', initializer=tf.constant(10, dtype=tf.float32))
w2 = tf.get_variable(name='w2', initializer=tf.constant(3, dtype=tf.float32))
inter = w1 * w2
inter = tf.stop_gradient(inter)
loss = w1*w1 - inter - 10
dL_dW = tf.gradients(loss, [w1, w2])
# Replace None gradient with 0 manully
dL_dW = [tf.constant(0) if grad is None else grad for grad in dL_dW]
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(dL_dW))