带2D元素的Numpy矩阵乘法

时间:2018-08-30 09:33:04

标签: python numpy matrix-multiplication

我有一个a numpy ndarray 3x3矩阵,看起来像这样

a =  ([[ uu, uv, uw],
       [ uv, vv, vw],
       [ uw, vw, ww]])

每个组件本身都是大小为(N,M)的2D数组,因此a矩阵的形状为(3,3,N,M)

如何以Python方式执行a*a的矩阵乘法? 使用a@a会引发以下错误(对于N = 1218和M = 540):

  

ValueError:形状(3,3,1218,540)和(3,3,1218,540)不对齐:540   (第3维)!= 1218(第2维)

我希望能够执行此操作,就好像a的元素只是标量值,而a@a不会引发与其形状相关的错误,因为它是简单的3x3矩阵乘法。

谢谢。

1 个答案:

答案 0 :(得分:1)

假设您要对最后两个轴上的每个元素执行矩阵乘法,我们可以使用np.einsum-

np.einsum('ijkl,jmkl->imkl',a,a)

运行示例以进行验证-

In [43]: np.random.seed(0)

In [44]: a = np.random.rand(3,3,4,5)

In [45]: a[:,:,0,0].dot(a[:,:,0,0])
Out[45]: 
array([[0.71750146, 1.17057872, 1.11135764],
       [0.62938365, 0.86437796, 0.74541383],
       [1.04636618, 1.62011127, 1.35483565]])

In [46]: np.einsum('ijkl,jmkl->imkl',a,a)[:,:,0,0]
Out[46]: 
array([[0.71750146, 1.17057872, 1.11135764],
       [0.62938365, 0.86437796, 0.74541383],
       [1.04636618, 1.62011127, 1.35483565]])