如何将火炬[N,1]和火炬[1,N]相乘?

时间:2020-10-26 02:34:55

标签: pytorch torch

我想通过将2个割炬矢量A(shape [N,1])与B = A'(shape [1,N])相乘来计算矩阵(shape [N,N])。

当我使用torch.matmultorch.mm时,出现错误或A'A(shape [1,1])。

如果A表示为 A = [a_1, a_2, ..., a_N]',我想计算一个矩阵C,其(i,j)元素为

for i in range(N):
     for j in range(N):
       C(i,j) = a_i * a_j

我想快速计算一下。你有什么想法? 谢谢您的帮助!

2 个答案:

答案 0 :(得分:1)

如果我对您的理解正确,则可以执行以下操作:

A = torch.randint(0,5,(5,))
C = (A.view(1, -1) * A.view(-1, 1)).to(torch.float)

它产生:

tensor([[ 1.,  4.,  3.,  3.,  0.],
        [ 4., 16., 12., 12.,  0.],
        [ 3., 12.,  9.,  9.,  0.],
        [ 3., 12.,  9.,  9.,  0.],
        [ 0.,  0.,  0.,  0.,  0.]])

等效于书写:

 D = torch.zeros((5,5))
for i in range(5):
  for j in range(5):
    D[i][j] = A[i] * A[j]

结果为:

tensor([[ 1.,  4.,  3.,  3.,  0.],
        [ 4., 16., 12., 12.,  0.],
        [ 3., 12.,  9.,  9.,  0.],
        [ 3., 12.,  9.,  9.,  0.],
        [ 0.,  0.,  0.,  0.,  0.]])

答案 1 :(得分:0)

您可以简单地执行以下操作:

import torch
A = torch.randint(0, 5, (3, 2))
B = torch.randint(0, 5, (2, 3))

A:

tensor([[1, 3],
        [2, 1],
        [1, 3]])

B:

tensor([[1, 0, 3],
        [3, 4, 1]])
C = A @ B # python 3.5+

C:

tensor([[10, 12,  6],
        [ 5,  4,  7],
        [10, 12,  6]])