列 2D 张量乘以行 2D 张量等于 3d pytorch 张量

时间:2021-01-19 10:44:15

标签: python pytorch

给定 PyTorch A (a X m) 和 B (m X b) 中的 2 个张量 2-D,是否有任何有效的方法来获得张量 C (m X a X b),其中C[i,:,:] = A[:,i] @ B[i,:]

这里我举一个例子:

A = torch.FloatTensor([[1,2],[3,4]])
B = torch.FloatTensor([[1,2,3],[4,5,6]])

结果:

C = torch.FloatTensor([[[1,2,3],[3,6,9]],[[12,15,18],[16,20,24]]])

我使用 for 循环完成了它。然而,它非常低效。

1 个答案:

答案 0 :(得分:0)

看看torch.einsum

C = torch.einsum('im,mj->mij', A, B)