numpy-einsum表示法:一堆矩阵与向量栈的点积

时间:2018-10-10 02:33:31

标签: python numpy numpy-broadcasting

我想将m * m个矩阵的n-dim堆栈乘以矢量(长度m)的n-dim堆栈,以便得到的m * n数组包含矩阵和矢量的点积的结果在第n个条目中:

vec1=np.array([0,0.5,1,0.5]); vec2=np.array([2,0.5,1,0.5])
vec=np.transpose(n.stack((vec1,vec2)))
mat = np.moveaxis(n.array([[[0,1,2,3],[0,1,2,3],[0,1,2,3],[0,1,2,3]],[[-1,2.,0,1.],[0,0,-1,2.],[0,1,-1,2.],[1,0.1,1,1]]]),0,2)
outvec=np.zeros((4,2))
for i in range(2):
    outvec[:,i]=np.dot(mat[:,:,i],vec[:,i])

受此帖子Element wise dot product of matrices and vectors的启发,我尝试了einsum中所有不同的索引组合扰动,并发现

np.einsum('ijk,jk->ik',mat,vec)

给出正确的结果。

不幸的是,我真的不明白这一点-我认为我在'ijk,jk'部分重复输入k意味着我将AND相加并乘以k。我试图阅读文档https://docs.scipy.org/doc/numpy-1.15.1/reference/generated/numpy.einsum.html,但我仍然不明白。

(包括我以前的尝试,

 np.einsum('ijk,il->ik', mat, vec)

我什至不知道这意味着什么。当我删除索引l时会发生什么?)

谢谢!

3 个答案:

答案 0 :(得分:2)

继续阅读Einstein summation notation

基本上,规则是:

没有->

  • 输入中重复的任何字母表示要倍增和求和的轴
  • 输入中未重复的任何字母都包含在输出中

使用->

  • 输入中重复的任何字母表示要乘以的轴
  • 输出中没有的任何字母表示要求和的轴

例如,具有相同形状的矩阵AB

np.einsum('ij, ij',       A, B)  # is A ddot B,                returns 0d scalar
np.einsum('ij, jk',       A, B)  # is A dot  B,                returns 2d tensor
np.einsum('ij, kl',       A, B)  # is outer(A, B),             returns 4d tensor
np.einsum('ji, jk, kl',   A, B)  # is A.T @ B @ A,             returns 2d tensor
np.einsum('ij, ij -> ij', A, B)  # is A * B,                   returns 2d tensor
np.einsum('ij, ij -> i' , A, A)  # is norm(A, axis = 1),       returns 1d tensor
np.einsum('ii'             , A)  # is tr(A),                   returns 0d scalar

答案 1 :(得分:1)

In [321]: vec1=np.array([0,0.5,1,0.5]); vec2=np.array([2,0.5,1,0.5])
     ...: vec=np.transpose(np.stack((vec1,vec2)))
In [322]: vec1.shape
Out[322]: (4,)
In [323]: vec.shape
Out[323]: (4, 2)

关于stack函数的一件好事是我们可以指定一条轴,跳过移调:

In [324]: np.stack((vec1,vec2), axis=1).shape
Out[324]: (4, 2)

为什么np.n.混合使用? NameError: name 'n' is not defined。这种事情几乎把我送走了。

In [326]: mat = np.moveaxis(np.array([[[0,1,2,3],[0,1,2,3],[0,1,2,3],[0,1,2,3]],[[-1,2.,0
     ...: ,1.],[0,0,-1,2.],[0,1,-1,2.],[1,0.1,1,1]]]),0,2)
In [327]: mat.shape
Out[327]: (4, 4, 2)

In [328]: outvec=np.zeros((4,2))
     ...: for i in range(2):
     ...:     outvec[:,i]=np.dot(mat[:,:,i],vec[:,i])
     ...:     
In [329]: outvec
Out[329]: 
array([[ 4.  , -0.5 ],
       [ 4.  ,  0.  ],
       [ 4.  ,  0.5 ],
       [ 4.  ,  3.55]])

In [330]: # (4,4,2) (4,2)   'kji,ji->ki'

在循环中,i轴(大小2)的位置很清楚-在所有3个数组中都位于最后。为vec留下一个轴,让我们称之为j。它与末尾(i的{​​{1}}之后)配对。 matk转移到mat

outvec

通常In [331]: np.einsum('kji,ji->ki', mat, vec) Out[331]: array([[ 4. , -0.5 ], [ 4. , 0. ], [ 4. , 0.5 ], [ 4. , 3.55]]) 字符串会自行写入。例如,如果einsum被描述为(m,n,k),而mat被描述为(n,k),则结果为(m,k)

在这种情况下,仅对vec维进行求和-它显示在左侧,但显示在右侧。我的符号中的最后一个维度j未被累加,因为如果出现在两面,就像在迭代中一样。我认为这是“一起骑”。它不是i产品的积极组成部分。

实际上,您正在堆叠最后一个尺寸为2的尺寸。通常,我们将堆栈放在第一个堆栈上,但是您将两个堆栈都放在最后一个堆栈上。


您的“失败”尝试运行,并且可以复制为:

dot

In [332]: np.einsum('ijk,il->ik', mat, vec) Out[332]: array([[12. , 4. ], [ 6. , 1. ], [12. , 4. ], [ 6. , 3.1]]) In [333]: mat.sum(axis=1)*vec.sum(axis=1)[:,None] Out[333]: array([[12. , 4. ], [ 6. , 1. ], [12. , 4. ], [ 6. , 3.1]]) j尺寸未显示在右侧,因此将它们相加。可以在相乘之前对它们进行求和,因为它们每个仅出现一次。我添加了l以启用广播(将Noneik相乘)。

i

如果您堆叠在第一个上,并为np.einsum('ik,i->ik', mat.sum(axis=1), vec.sum(axis=1)) (2,4,1)添加尺寸,它将vec带有(2,4,4)垫。 matmul

mat @ vec[...,None]

答案 2 :(得分:1)

einsum很简单(当您使用索引排列播放了一段时间时,就是...)。

让我们使用简单的东西工作,三叠的 2×2 矩阵和三叠的 2×,数组

import numpy as np

a = np.arange(3*2*2).reshape((3,2,2))
b = np.arange(3*2).reshape((3,2))

我们需要知道我们将使用einsum

计算什么
In [101]: for i in range(3): 
     ...:     print(a[i]@b[i])                                                                            
[1 3]
[23 33]
[77 95]

我们做了什么?我们有一个索引i,当我们在其中一个堆叠矩阵和一个堆叠向量(都由i索引)之间执行点积时,该索引是固定的,并且单个输出线暗示着对堆积矩阵的最后一个索引和堆积向量的孤立索引。

这很容易用einsum指令进行编码

  • 我们希望使用相同的i索引来指定矩阵,向量以及输出,
  • 我们要沿着最后一个矩阵索引和其余的矢量索引(例如k
  • 进行缩减)
  • 我们希望输出中的列与每个堆叠矩阵中的行一样多,例如j

因此

In [102]: np.einsum('ijk,ik->ij', a, b)                                                                   
Out[102]: 
array([[ 1,  3],
       [23, 33],
       [77, 95]])

我希望我对如何正确使用指令的讨论是清晰,正确和有用的。