我有一个自定义渐变,该渐变计算两个张量的内积:在最后一个维度上的vector_conj_inner(a,b)。张量'b'具有广播轴,因此b的梯度必须在该轴上减小为1。
举个例子,采用以下两个张量a,b。为了弄清楚前进和后退路径中发生了什么,我在注释中写下了相应的形状。
a.shape = (3,5,7)
b.shape = (5,7)
b = tf.expand_dims(b, axis=0) # shape = (1,5,7)
s = vector_conj_inner(a,b)
def vector_conj_inner(a, b):
s = tf.reduce_sum(a*tf.conj(b), -1) # shape = (3,5)
def grad(grad_s):
grad_s = tf.expand_dims(grad_s, axis=-1) # shape = (3,5,1)
grad_a = grad_s*b # shape = (3,5,7)
grad_b = tf.conj(grad_s)*a # shape = (3,5,7)
grad_b = tf.reduce_sum(grad_b, axis=0, keepdims=True) # shape = (1,5,7)
return grad_a, grad_b
return s, grad
当我将“ grad_b”的轴0减小为1时,可以看到形状可以计算出来。
我现在的问题是,如何使它更通用,即自动检测“ b”中的哪个轴为1,并根据该信息减少“ grad_b”?我需要的是类似于条件reduce_sum的条件,其中条件是b的轴为1。
有什么想法吗?