按权重

时间:2017-05-05 22:27:41

标签: tensorflow

我正在训练一个网络,使张量t1之一具有以下形状:

shape(t1) = [?, 300, 300, 10]

和另一个张量t2有形状:

shape(t2) = [?, 10]

我希望张量t2的每个切片[300, 300]多个t1张量的每个元素。有谁知道怎么做?到目前为止,我写了以下内容:

def mul_concat(I):
    A = []
    for i in range(d1.shape[1].value):
        A.append(d1[:, i]*I[:, :, :, i]))
return reduce(lambda a, b: a+b, A)

但是,由于batch size维度,我收到错误。任何想法如何解决?

1 个答案:

答案 0 :(得分:0)

如果您重塑t2以塑造[?, 1, 1, 10],那么Tensorflow的广播规则将完成其余的工作:

t2_reshaped = tf.reshape(t2, [-1, 1, 1, 10])
output = t1 * t2_reshaped

许多Tensorflow运营商允许广播;广播规则与numpy广播规则相同。见https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html

希望有所帮助!