TensorFlow自定义渐变

时间:2017-10-22 15:50:43

标签: tensorflow backpropagation

我有一个自定义渐变计算功能,可以将输入的渐变加倍。

import tensorflow as tf

@tf.RegisterGradient("CustomSquare")
def _custom_square_grad(op, grad):
  return grad*2.0

c = tf.constant(3.)

s1 = tf.square(c)
grad1 = tf.gradients(s1, c)[0]

g = tf.get_default_graph()
with g.gradient_override_map({"Square": "CustomSquare"}):
    s2 = tf.square(c)
    grad2 = tf.gradients(s2, c)[0]

with tf.Session() as sess:
    print(sess.run([c, s1, grad1]))
    print(sess.run([c, s2, grad2]))

我得到的结果令人惊讶:

[3.0, 9.0, 6.0]
[3.0, 9.0, 2.0]

我原以为第二个结果是[3.0, 9.0, 12.0]。我错过了什么?

感谢。

1 个答案:

答案 0 :(得分:1)

简而言之,_custom_square_grad的正确版本应为:

@tf.RegisterGradient("CustomSquare")                                             
def _custom_square_grad(op, grad):                                               
    x = op.inputs[0]                                                            
    return 2.0 * (grad * 2.0 * x)

为了理解代码,您需要了解gradient的工作原理。定义tf.RegisterGradient时,应该将渐变从输出反向传播到输入。对于tf.squre,默认的渐变函数是这样的:

# Given y = tf.square(x) => y' = 2x
grad_x = grad_y * 2.0 * x

由于您希望在自定义渐变功能中加倍渐变,因此您只需将其更改为grad_x = 2.0 * (grad_y * 2.0 * x)