Tensorflow:gradient_override_map不能覆盖op tf.stack的向后渐变

时间:2017-10-30 10:56:11

标签: python tensorflow

我正在尝试使用tf.stacktf.RegisterGradient修改tf.gradient_override_map op后向渐变计算机制,以下是我的代码:

import tensorflow as tf

class SynthGradBuilder(object):
    def __init__(self):
        self.num_calls = 0

    def __call__(self, x, l=1.0):
        op_name = "SynthGrad%d" % self.num_calls
        @tf.RegisterGradient(op_name)
        def _grad_synth(op, grad):
            return grad[0]

        g = tf.get_default_graph()
        with g.gradient_override_map({"stack": op_name}):
            y = tf.stack([x,x])

        self.num_calls += 1
        return y

GradSys = SynthGradBuilder()

在另一个剧本中,我写了

import tensorflow as tf
from gradient_synthesizer import GradSys

x = tf.Variable([1,2])
y = GradSys(x, l=1)
z = tf.stack([x,x])


grad = tf.gradients(y, x, grad_ys=[[tf.convert_to_tensor([3, 4]), 
                              tf.convert_to_tensor([6, 8])]])
grad_stack = tf.gradients(z, x, grad_ys=[[tf.convert_to_tensor([3, 4]), 
                              tf.convert_to_tensor([6, 8])]])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    print "grad bp: ", sess.run(grad)
    print "grad_stack: ", sess.run(grad_stack)
    print "y: ", sess.run(y)

预期输出应为:

grad bp: [3,4];
grad_stack: [3+6, 4+8] = [9, 12];
y: [[1,2], [1,2]];

我从代码中得到的是:

my result

表示tf.stack的向后渐变根本没有被替换,这与我的期望相反。

我不确定这种差异是否是错误地使用" stack"作为操作tf.stack的类型字符串,我按以下方式进行了实验:

my validation

描述张量y的第一项,"堆栈:0"建议op tf.stack的注册名称是" stack",这也是它的类型字符串。所以它似乎不是"堆栈"的错误。

我无法弄清楚我的代码的原因'问题。我想知道是否有人可以帮助我。

1 个答案:

答案 0 :(得分:4)

Tl; dr:正确的代码应该是:

@tf.RegisterGradient(op_name)
def _grad_synth(op, grad):
  x, y = tf.unstack(grad)
  return [x, tf.zeros_like(y)]

g = tf.get_default_graph()
with g.gradient_override_map({"Pack": op_name}):
  y = tf.stack([x, x])

因为这是一个非常常见的问题,我想解释一下细节:

原始代码中存在两个主要问题:

  1. 错误使用gradient_override_map
  2. tf.stack的实际OP名称为Pack(不是Stack),因此您需要选择Pack而不是Stack

    `g.gradient_override_map({"Pack": op_name})`.
    

    您可能想知道我如何知道实际的OP名称?好吧,一个简单的方法是通过运行以下代码来探测GraphDef:

    with tf.Graph().as_default():
      x = tf.constant(0)
      y = tf.stack([x, x])
      print(tf.get_default_graph().as_graph_def())
    
    1. 错误的渐变功能:
    2. Pack的原始渐变是一个简单的Unpackofficial code)。在您的情况下,您仍然需要首先解压缩渐变,但只传播FIRST部分:

      @tf.RegisterGradient(op_name)
      def _grad_synth(op, grad):
        x, y = tf.unstack(grad)
        return [x, tf.zeros_like(y)]
      

      请注意,此代码适合您的情况。但是,如果要支持任何长度的堆栈,可以使用稍微复杂的版本:

      @tf.RegisterGradient(op_name)
      def _grad_synth(op, grad):
        x_list = tf.unstack(grad)
        for i in range(1, len(x_list)):
          x_list[i] = tf.zeros_like(x_list[i])
        return x_list