Understanding PyTorch einsum

Date: 2019-04-28 21:23:57

Tags: python numpy pytorch numpy-einsum

I'm familiar with how einsum works in NumPy. PyTorch offers similar functionality: torch.einsum(). What are the similarities and differences, in functionality or performance? The information in the PyTorch documentation is sparse and provides no insight on this.

1 Answer:

Answer 0 (score: 1)

Since the description of einsum in the torch documentation is scanty, I decided to write this post to document, compare, and contrast the behavior of torch.einsum() and numpy.einsum().

Differences:

  • NumPy allows both lowercase and uppercase letters [a-zA-Z] in the "subscript string", whereas PyTorch allows only lowercase letters [a-z]
  • In addition to nd-arrays, NumPy supports many keyword arguments (e.g. optimize), while PyTorch doesn't yet offer such flexibility
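The NumPy side of these differences can be checked directly (a minimal sketch; whether newer PyTorch versions have since lifted the corresponding restrictions depends on the release):

```python
import numpy as np

a = np.arange(1, 5).reshape(2, 2)
b = np.ones((2, 2), dtype=int)

# NumPy accepts uppercase subscript letters; the two spellings are equivalent
r_upper = np.einsum('IJ, JK -> IK', a, b)
r_lower = np.einsum('ij, jk -> ik', a, b)
assert np.array_equal(r_upper, r_lower)

# NumPy also takes extra keyword arguments, e.g. optimize for contraction-order search
r_opt = np.einsum('ij, jk -> ik', a, b, optimize=True)
assert np.array_equal(r_opt, r_lower)
```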

Below are implementations of some examples in both PyTorch and NumPy:

# input tensors

In [16]: vec
Out[16]: tensor([0, 1, 2, 3])

In [17]: aten
Out[17]: 
tensor([[11, 12, 13, 14],
        [21, 22, 23, 24],
        [31, 32, 33, 34],
        [41, 42, 43, 44]])

In [18]: bten
Out[18]: 
tensor([[1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3],
        [4, 4, 4, 4]])
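For readers who want to reproduce the NumPy side of the examples below, the same inputs can be rebuilt as plain ndarrays (a sketch; the post itself only shows the PyTorch tensors, and the names arr1/arr2 are my own):

```python
import numpy as np

# vec: [0, 1, 2, 3]
vec = np.arange(4)

# aten's values: entry (i, j) is 10*(i+1) + (j+1), i.e. 11..44
arr1 = 10 * np.arange(1, 5)[:, None] + np.arange(1, 5)

# bten's values: each row i filled with the constant i+1
arr2 = np.repeat(np.arange(1, 5)[:, None], 4, axis=1)
```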

1) Matrix multiplication
     PyTorch: torch.matmul(aten, bten) or aten.mm(bten)
     NumPy: np.einsum("ij, jk -> ik", arr1, arr2)

In [19]: torch.einsum('ij, jk -> ik', aten, bten)
Out[19]: 
tensor([[130, 130, 130, 130],
        [230, 230, 230, 230],
        [330, 330, 330, 330],
        [430, 430, 430, 430]])

2) Extract elements along the main diagonal
   PyTorch: torch.diag(aten)
   NumPy: np.einsum("ii -> i", arr)

In [28]: torch.einsum('ii -> i', aten)
Out[28]: tensor([11, 22, 33, 44])

3) Hadamard product (i.e. element-wise product of two tensors)
   PyTorch: aten * bten
   NumPy: np.einsum("ij, ij -> ij", arr1, arr2)

In [34]: torch.einsum('ij, ij -> ij', aten, bten)
Out[34]: 
tensor([[ 11,  12,  13,  14],
        [ 42,  44,  46,  48],
        [ 93,  96,  99, 102],
        [164, 168, 172, 176]])

4) Element-wise squaring
   PyTorch: aten ** 2
   NumPy: np.einsum("ij, ij -> ij", arr, arr)

In [37]: torch.einsum('ij, ij -> ij', aten, aten)
Out[37]: 
tensor([[ 121,  144,  169,  196],
        [ 441,  484,  529,  576],
        [ 961, 1024, 1089, 1156],
        [1681, 1764, 1849, 1936]])

General: an element-wise nth power can be implemented by repeating the subscript string and the tensor n times. For example, the element-wise 4th power of a tensor can be computed using:

# NumPy: np.einsum('ij, ij, ij, ij -> ij', arr, arr, arr, arr)
In [38]: torch.einsum('ij, ij, ij, ij -> ij', aten, aten, aten, aten)
Out[38]: 
tensor([[  14641,   20736,   28561,   38416],
        [ 194481,  234256,  279841,  331776],
        [ 923521, 1048576, 1185921, 1336336],
        [2825761, 3111696, 3418801, 3748096]])

5) Trace (i.e. sum of main-diagonal elements)
   PyTorch: torch.trace(aten)
   NumPy einsum: np.einsum("ii -> ", arr)

In [44]: torch.einsum('ii -> ', aten)
Out[44]: tensor(110)
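As a quick sanity check on the NumPy side, the same subscript string matches np.trace (a minimal sketch rebuilding an ndarray with aten's values):

```python
import numpy as np

# same values as the aten tensor above (11..44)
arr = 10 * np.arange(1, 5)[:, None] + np.arange(1, 5)

t = np.einsum('ii -> ', arr)   # 11 + 22 + 33 + 44
assert t == np.trace(arr) == 110
```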

6) Matrix transpose
   PyTorch: torch.transpose(aten, 1, 0)
   NumPy einsum: np.einsum("ij -> ji", arr)

In [58]: torch.einsum('ij -> ji', aten)
Out[58]: 
tensor([[11, 21, 31, 41],
        [12, 22, 32, 42],
        [13, 23, 33, 43],
        [14, 24, 34, 44]])

7) Outer product (of vectors)
   PyTorch: torch.ger(vec, vec)
   NumPy einsum: np.einsum("i, j -> ij", vec, vec)

In [73]: torch.einsum('i, j -> ij', vec, vec)
Out[73]: 
tensor([[0, 0, 0, 0],
        [0, 1, 2, 3],
        [0, 2, 4, 6],
        [0, 3, 6, 9]])

8) Inner product (of vectors)
   PyTorch: torch.dot(vec1, vec2)
   NumPy einsum: np.einsum("i, i -> ", vec1, vec2)

In [76]: torch.einsum('i, i -> ', vec, vec)
Out[76]: tensor(14)
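On the NumPy side, the same contraction reduces to np.dot (a small check using an ndarray with vec's values):

```python
import numpy as np

v = np.arange(4)                   # same values as vec
ip = np.einsum('i, i -> ', v, v)   # 0*0 + 1*1 + 2*2 + 3*3
assert ip == np.dot(v, v) == 14
```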

9) Sum along axis 0
   PyTorch: torch.sum(aten, 0)
   NumPy einsum: np.einsum("ij -> j", arr)

In [85]: torch.einsum('ij -> j', aten)
Out[85]: tensor([104, 108, 112, 116])

10) Sum along axis 1
    PyTorch: torch.sum(aten, 1)
    NumPy einsum: np.einsum("ij -> i", arr)

In [86]: torch.einsum('ij -> i', aten)
Out[86]: tensor([ 50,  90, 130, 170])

11) Batch matrix multiplication
    PyTorch: torch.bmm(batch_ten, batch_ten)
    NumPy: np.einsum("bij, bjk -> bik", batch_ten, batch_ten)

In [90]: batch_ten = torch.stack((aten, bten), dim=0)
In [91]: batch_ten
Out[91]: 
tensor([[[11, 12, 13, 14],
         [21, 22, 23, 24],
         [31, 32, 33, 34],
         [41, 42, 43, 44]],

        [[ 1,  1,  1,  1],
         [ 2,  2,  2,  2],
         [ 3,  3,  3,  3],
         [ 4,  4,  4,  4]]])

In [92]: batch_ten.shape
Out[92]: torch.Size([2, 4, 4])

# batch matrix multiply using einsum
In [96]: torch.einsum("bij, bjk -> bik", batch_ten, batch_ten)
Out[96]: 
tensor([[[1350, 1400, 1450, 1500],
         [2390, 2480, 2570, 2660],
         [3430, 3560, 3690, 3820],
         [4470, 4640, 4810, 4980]],

        [[  10,   10,   10,   10],
         [  20,   20,   20,   20],
         [  30,   30,   30,   30],
         [  40,   40,   40,   40]]])
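This batched contraction can be cross-checked against np.matmul, which likewise broadcasts over the leading batch axis (a sketch rebuilding the same 2x4x4 batch in NumPy):

```python
import numpy as np

# rebuild the batch from the example above: aten's and bten's values stacked
a = 10 * np.arange(1, 5)[:, None] + np.arange(1, 5)   # 11..44
b = np.repeat(np.arange(1, 5)[:, None], 4, axis=1)    # rows of constants 1..4
batch = np.stack((a, b), axis=0)                      # shape (2, 4, 4)

r_einsum = np.einsum('bij, bjk -> bik', batch, batch)
r_matmul = np.matmul(batch, batch)   # batch-wise matrix multiply
assert np.array_equal(r_einsum, r_matmul)
```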

12) Sum along axis 2
    PyTorch: torch.sum(batch_ten, 2)
    NumPy einsum: np.einsum("ijk -> ij", arr3D)

In [99]: torch.einsum("ijk -> ij", batch_ten)
Out[99]: 
tensor([[ 50,  90, 130, 170],
        [  4,   8,  12,  16]])

13) Sum over all elements in an nD tensor
    PyTorch: torch.sum(batch_ten)
    NumPy einsum: np.einsum("ijk -> ", arr3D)

In [101]: torch.einsum("ijk -> ", batch_ten)
Out[101]: tensor(480)

14) Sum over multiple axes (i.e. marginalization)
    PyTorch: torch.sum(arr, dim=(dim0, dim1, dim2, dim3, dim4, dim6, dim7))
    NumPy: np.einsum("ijklmnop -> n", nDarr)

# 8D tensor
In [103]: nDten = torch.randn((3,5,4,6,8,2,7,9))
In [104]: nDten.shape
Out[104]: torch.Size([3, 5, 4, 6, 8, 2, 7, 9])

# marginalize out dimension 5 (i.e. "n" here)
In [111]: esum = torch.einsum("ijklmnop -> n", nDten)
In [112]: esum
Out[112]: tensor([  98.6921, -206.0575])

# marginalize out axis 5 (i.e. sum over rest of the axes)
In [113]: tsum = torch.sum(nDten, dim=(0, 1, 2, 3, 4, 6, 7))

In [115]: torch.allclose(tsum, esum)
Out[115]: True

15) Double dot product (same as torch.sum(hadamard-product); cf. 3)
    PyTorch: torch.sum(aten * bten)
    NumPy: np.einsum("ij, ij -> ", arr1, arr2)

In [120]: torch.einsum("ij, ij -> ", aten, bten)
Out[120]: tensor(1300)
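The same identity holds on the NumPy side: the double dot product is the sum of the Hadamard product (a quick check with ndarrays carrying aten's and bten's values):

```python
import numpy as np

a = 10 * np.arange(1, 5)[:, None] + np.arange(1, 5)   # aten's values
b = np.repeat(np.arange(1, 5)[:, None], 4, axis=1)    # bten's values

dd = np.einsum('ij, ij -> ', a, b)
assert dd == (a * b).sum() == 1300
```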