如何有条件地减少张量

时间:2019-09-02 20:20:55

标签: tensorflow gradient reduce

我有一个自定义渐变,该渐变计算两个张量的内积:在最后一个维度上的vector_conj_inner(a,b)。张量'b'具有广播轴,因此b的梯度必须在该轴上减小为1。

举个例子,采用以下两个张量a,b。为了弄清楚前进和后退路径中发生了什么,我在注释中写下了相应的形状。


a.shape = (3,5,7)
b.shape = (5,7)

b = tf.expand_dims(b, axis=0)                      # shape = (1,5,7)
s = vector_conj_inner(a,b)

def vector_conj_inner(a, b):
    s = tf.reduce_sum(a*tf.conj(b), -1)            # shape = (3,5)

    def grad(grad_s):
        grad_s = tf.expand_dims(grad_s, axis=-1)   # shape = (3,5,1)
        grad_a = grad_s*b                          # shape = (3,5,7)
        grad_b = tf.conj(grad_s)*a                 # shape = (3,5,7)
        grad_b = tf.reduce_sum(grad_b, axis=0, keepdims=True)   # shape = (1,5,7)
        return grad_a, grad_b

    return s, grad

当我将“ grad_b”的轴0减小为1时,可以看到形状可以计算出来。

我现在的问题是,如何使它更通用,即自动检测“ b”中的哪个轴为1,并根据该信息减少“ grad_b”?我需要的是类似于条件reduce_sum的条件,其中条件是b的轴为1。

有什么想法吗?

0 个答案:

没有答案