用另一个函数的输出替换Tensorflow函数的渐变

时间:2017-10-09 12:33:55

标签: python tensorflow

我正在使用由复合Tensorflow操作组成的函数。但是,我不想让Tensorflow根据其中一个输入自动计算其导数,而是想在同一输入上用不同的计算替换渐变。此外,一些计算在前向和后向传递之间共享。例如:

def func(in1, in2):
    # do something with inputs using only tf operations
    shared_rep = tf.op1(tf.op2(tf.op3(in1, in2))) # same computation for both forward and gradient pass
    # return output of forward computation
    return tf.op4(shared_rep)

def func_grad(in1, in2):
    shared_rep = tf.op1(tf.op2(tf.op3(in1, in2)))
    # explicitly calculate gradients with respect to in1, with the intention of replacing the gradients computed by Tensorflow
    mygrad1 = tf.op5(tf.op6(shared_rep))
    return mygrad1

in1 = tf.Variable([1,2,3])
in2 = tf.Variable([2.5,0.01])
func_val = func(in1, in2)
my_grad1 = func_grad(in1, in2)
tf_grad1 = tf.gradients(func_val, in1)
with tf.Session() as sess:
    # would like tf_grad1 to equal my_grad1
    val, my1, tf1 = sess.run([func_val, my_grad1, tf_grad1])
    tf.assert_equal(my1, tf1)

注意:这类似于问题How to replace or modify gradient?,但有一个关键区别:我对后向传递中不同函数的Tensorflow计算梯度不感兴趣;相反,我想根据输入上的替代张量流操作自己提供渐变。

我正在尝试使用solution to the above questionthe following post中提出的想法,即使用 tf.RegisterGradient gradient_override_map 来覆盖包含前向函数的标识函数的梯度。 这失败了,因为在注册的替代标识内部,我无法访问func_grad的输入:

@tf.RegisterGradient("CustomGrad")
def alternate_identity_grad(op, grad):
    # op.inputs[0] is the output of func(in1,in2)
    # grad is of no use, because I would like to replace it with func_grad(in1,in2)

g = tf.get_default_graph()
with g.gradient_override_map({"Identity": "CustomGrad"}):
    out_grad = tf.identity(input, name="Identity")

编辑经过进一步研究,我认为这个问题类似于the following question。我设法通过将 gradient_override_map 与hack suggested here结合起来获得所需的解决方案。

0 个答案:

没有答案