试图理解numpy中的基本矩阵乘法

时间:2018-02-14 00:52:39

标签: python numpy

我无法理解我在numpy中遇到的错误。从线性代数我知道A.T@b.T=(b@A).T其中@是点积,而.T是转置矩阵。但是在numpy中我得到了以下错误:

a = np.arange(4).reshape(1,4)
b = np.arange(4*3*2*5).reshape(4,3,2,5)
a@b

File "<stdin>", line 1, in <module>
ValueError: shapes (1,4) and (4,3,2,5) not aligned: 4 (dim 1) != 2 (dim 2)

然而b.T@a.T工作正常。所以我的问题是:为什么(b.T@a.T).T有效,为什么a@b不起作用?

1 个答案:

答案 0 :(得分:0)

我的解释方式:

b.T @ aT is 
(5,2,3,4) @ (4,1) => (5,2,3 1)

最后一个尺寸dot和(5,2)&#39;适用于骑行&#39;:

(3,4) @ (4,1) => (3,1)

也就是说,正确的dot是最后一个维度,点缀着倒数第二个,即4个匹配

@matmul的文档为https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html#numpy.matmul

  

如果任一参数是N-D,则N> 2,它被视为驻留在最后两个索引中的一堆矩阵并相应地进行广播。

     

如果两个参数都是2-D,它们就像传统矩阵一样倍增。

在     a @ b是(1,4)@(4,3,2,5)

(1,4)与b的最后2个维度(2,5)无差别地匹配

In [158]: a@b
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-158-a832bb91e25d> in <module>()
----> 1 a@b

ValueError: shapes (1,4) and (4,3,2,5) not aligned: 4 (dim 1) != 2 (dim 2)

a的最后一个暗点是4.它与b的第二个到最后一个不匹配,即2。

在爱因斯坦符号中,ab可以点缀在&#39;用:

In [159]: np.einsum('ij,jklm->iklm',a,b).shape
Out[159]: (1, 3, 2, 5)

请注意,j维度是2个阵列的共同维度。

In [160]: np.einsum('ijkm,ml->ijkl',b.T, a.T).shape     # pairing the m's
Out[160]: (5, 2, 3, 1)