在3D张量行(Tensorflow)上进行广播操作

时间:2018-04-27 05:33:26

标签: python tensorflow

我的尺寸张量为[BATCH_SIZE,128],张力A,B都为尺寸[528,128]。我想构建一个大小为[BATCH_SIZE,528]的新张量p',其中列j定义为:

tf.reduce_prod(self.A[j,:] * p + self.B[j,:], axis=1)

我目前使用for循环进行了强制实现,但它非常慢。有什么方法可以使用广播或其他东西来加快速度吗?

ps = []
for j in xrange(self.A.shape[0]):
    a = self.A[j,:]
    b = self.B[j,:]
    ps.append(tf.reduce_prod(a * p + b, axis=1))
p = tf.stack(ps, axis=1)

1 个答案:

答案 0 :(得分:1)

试试这个:

result = tf.reduce_prod(self.A * p[:, tf.newaxis] + self.B, axis=2)