假设我在一个MxMxN阵列中有N个MxM阵列。是否有任何简单的方法在numpy中进行连续MxM数组的累积矩阵乘法(可以覆盖MxMxN数组)。我可以用如下所示的循环来完成它,但我想知道是否有更好的方法?请注意,订购MxMxN并不特别,我可以轻松拥有NxMxM或其他东西。
import numpy as np
a = np.arange(4).reshape((2,2))
n=3
b = np.dstack((a,)*n)
print(b[:,:,0])
#[b[:,:,k].dot(b[:,:, k - 1], out=b[:,:, k]) for k in range(1, n)]
for k in range(1, n):
b[:,:,k] = np.dot(b[:,:,k], b[:,:,k-1])
print(b[:, :, k])
我从中得到输出:
[[0 1]
[2 3]]
[[ 2 3]
[ 6 11]]
[[ 6 11]
[22 39]]
我还尝试了以下列表理解失败:
[b[:,:,k].dot(b[:,:, k - 1], out=b[:,:, k]) for k in range(1, n)]
修改 我讨论了所有中间结果,所以b [:,:,0],b [:,:0] xb [0:,:,1],b [:,:0] xb [0 :,:,1] xb [:,:,2]等不仅仅是最后的b [:,:,0] xb [0:,:,1] x ... xb [:,:, - N 1]
答案 0 :(得分:2)
对于M * M * N阵列,如何:
reduce(np.dot, np.rollaxis(b, 2))
对于Python 3,您需要从reduce
导入functools
。