我的尺寸张量为[BATCH_SIZE,128],张力A,B都为尺寸[528,128]。我想构建一个大小为[BATCH_SIZE,528]的新张量p',其中列j定义为:
tf.reduce_prod(self.A[j,:] * p + self.B[j,:], axis=1)
我目前使用for
循环进行了强制实现,但它非常慢。有什么方法可以使用广播或其他东西来加快速度吗?
ps = []
for j in xrange(self.A.shape[0]):
a = self.A[j,:]
b = self.B[j,:]
ps.append(tf.reduce_prod(a * p + b, axis=1))
p = tf.stack(ps, axis=1)
答案 0 :(得分:1)
试试这个:
result = tf.reduce_prod(self.A * p[:, tf.newaxis] + self.B, axis=2)