我正在尝试使用tf.stack
和tf.RegisterGradient
修改tf.gradient_override_map
op后向渐变计算机制,以下是我的代码:
import tensorflow as tf
class SynthGradBuilder(object):
def __init__(self):
self.num_calls = 0
def __call__(self, x, l=1.0):
op_name = "SynthGrad%d" % self.num_calls
@tf.RegisterGradient(op_name)
def _grad_synth(op, grad):
return grad[0]
g = tf.get_default_graph()
with g.gradient_override_map({"stack": op_name}):
y = tf.stack([x,x])
self.num_calls += 1
return y
GradSys = SynthGradBuilder()
在另一个剧本中,我写了
import tensorflow as tf
from gradient_synthesizer import GradSys
x = tf.Variable([1,2])
y = GradSys(x, l=1)
z = tf.stack([x,x])
grad = tf.gradients(y, x, grad_ys=[[tf.convert_to_tensor([3, 4]),
tf.convert_to_tensor([6, 8])]])
grad_stack = tf.gradients(z, x, grad_ys=[[tf.convert_to_tensor([3, 4]),
tf.convert_to_tensor([6, 8])]])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print "grad bp: ", sess.run(grad)
print "grad_stack: ", sess.run(grad_stack)
print "y: ", sess.run(y)
预期输出应为:
grad bp: [3,4];
grad_stack: [3+6, 4+8] = [9, 12];
y: [[1,2], [1,2]];
我从代码中得到的是:
表示tf.stack
的向后渐变根本没有被替换,这与我的期望相反。
我不确定这种差异是否是错误地使用" stack"作为操作tf.stack
的类型字符串,我按以下方式进行了实验:
描述张量y的第一项,"堆栈:0"建议op tf.stack
的注册名称是" stack",这也是它的类型字符串。所以它似乎不是"堆栈"的错误。
我无法弄清楚我的代码的原因'问题。我想知道是否有人可以帮助我。
答案 0 :(得分:4)
Tl; dr:正确的代码应该是:
@tf.RegisterGradient(op_name)
def _grad_synth(op, grad):
x, y = tf.unstack(grad)
return [x, tf.zeros_like(y)]
g = tf.get_default_graph()
with g.gradient_override_map({"Pack": op_name}):
y = tf.stack([x, x])
因为这是一个非常常见的问题,我想解释一下细节:
原始代码中存在两个主要问题:
gradient_override_map
: tf.stack
的实际OP名称为Pack
(不是Stack
),因此您需要选择Pack
而不是Stack
:
`g.gradient_override_map({"Pack": op_name})`.
您可能想知道我如何知道实际的OP名称?好吧,一个简单的方法是通过运行以下代码来探测GraphDef:
with tf.Graph().as_default():
x = tf.constant(0)
y = tf.stack([x, x])
print(tf.get_default_graph().as_graph_def())
Pack
的原始渐变是一个简单的Unpack
(official code)。在您的情况下,您仍然需要首先解压缩渐变,但只传播FIRST部分:
@tf.RegisterGradient(op_name)
def _grad_synth(op, grad):
x, y = tf.unstack(grad)
return [x, tf.zeros_like(y)]
请注意,此代码适合您的情况。但是,如果要支持任何长度的堆栈,可以使用稍微复杂的版本:
@tf.RegisterGradient(op_name)
def _grad_synth(op, grad):
x_list = tf.unstack(grad)
for i in range(1, len(x_list)):
x_list[i] = tf.zeros_like(x_list[i])
return x_list