“ @”用于使用pytorch进行张量乘法

时间:2020-09-18 05:28:35

标签: python pytorch

本文https://towardsdatascience.com/understand-kaiming-initialization-and-implementation-detail-in-pytorch-f7aa967e9138中有关智能加权初始化的语法如下:

x@w

表示张量(/矩阵)乘法。我以前没看过,而是假设我们需要将其“拼写”为:

 torch.mm(x, w.t())

使用以前的(更精细)的语法需要什么?那篇文章没有显示他们正在使用的完整导入。

2 个答案:

答案 0 :(得分:2)

它不需要任何东西。只需import torch就足够了(两个操作数必须是张量)。例如,我尝试过

import torch
a = torch.randn((2, 2)) # tensor([[-0.3023, -1.3499], [-2.5096, -0.8977]])
b = torch.randn((2, 3)) # tensor([[-1.3319,  2.2378, -0.1892], [-0.3895, -0.5334, -0.5148]])
a@b

结果是

tensor([[ 0.9284,  0.0436,  0.7521],
        [ 3.6921, -5.1372,  0.9370]])

为了验证,我也做了

torch.matmul(a, b)

结果和以前一样

tensor([[ 0.9284,  0.0436,  0.7521],
        [ 3.6921, -5.1372,  0.9370]])

要注意的另一件事是,NumPy对于矩阵乘法也具有相同的@运算符(PyTorch通常尝试使用张量复制类似的行为,就像NumPy对其数组所做的那样)。

答案 1 :(得分:1)

仅Python 3.5及更高版本可以使用此“ @”语法。以下是等效的:

a = torch.rand(2,2)
b = torch.rand(2,2)

c = a.mm(b)
print(c)

c = torch.mm(a, b)
print(c)

c = torch.matmul(a, b)
print(c)

c = a @ b # python > 3.5+
print(c)

输出:

tensor([[0.2675, 0.8140],
        [0.0415, 0.1644]])
tensor([[0.2675, 0.8140],
        [0.0415, 0.1644]])
tensor([[0.2675, 0.8140],
        [0.0415, 0.1644]])
tensor([[0.2675, 0.8140],
        [0.0415, 0.1644]])

我喜欢使用mm语法进行矩阵到矩阵的乘法,使用mv进行矩阵到矢量的乘法。

要获得转置矩阵,我喜欢使用简单的a.T语法。

要添加的另一件事:

a = torch.rand(2,2,2)
b = torch.rand(2,2,2)

c = torch.matmul(a, b)
print(c)

c = a @ b # python > 3.5+
print(c)

输出:

tensor([[[0.2951, 0.3021],
         [0.8663, 1.0430]],

        [[0.2674, 1.3792],
         [0.0895, 0.9703]]])
tensor([[[0.2951, 0.3021],
         [0.8663, 1.0430]],

        [[0.2674, 1.3792],
         [0.0895, 0.9703]]])

mm不适用于等级> 2(等级3或更高的张量)。因此,如果您使用更大的等级,则只需使用以下符号:matmul@