如何使用广播来加速这段代码?

时间:2016-08-12 14:56:47

标签: python performance numpy numpy-broadcasting

我有以下多维数组。第一轴表示a 三维矢量。我想计算每个的3乘3矩阵x⋅x' 那些。

我目前的解决方案:

arr.shape
# (3, 64, 64, 33, 187)

dm = arr.reshape(3,-1)
dm.shape
# (3, 25276416)

cov = np.empty((3,3,dm.shape[1]))
cov.shape
# (3, 3, 25276416)

这个for循环迭代所有25,276,416个元素,大约需要1或2分钟。

for i in range(dm.shape[1]):
    cov[...,i] = dm[:,i].reshape(3,1).dot(dm[:,i].reshape(1,3))

cov = cov.reshape((3,) + arr.shape)
cov.shape
# (3, 3, 64, 64, 33, 187)

1 个答案:

答案 0 :(得分:1)

嗯,你并没有真正使用np.dot使用矩阵乘法减少任何轴,它只是在那里广播元素乘法。所以,您可以简单地使用NumPy broadcasting来完成整个事情,就像这样 -

cov = dm[:,None]*dm

或直接在arr上使用它,以避免创建dm以及所有重塑,例如 -

cov = arr[:,None]*arr