2D 乘以 2D 等于 3d pytorch 张量

时间:2021-01-19 00:58:18

标签: python pytorch matrix-multiplication

给定两个 2-D pytorch 张量:

A = torch.FloatTensor([[1,2],[3,4]])
B = torch.FloatTensor([[0,0],[1,1],[2,2]])

是否有一种有效的方法来计算形状为 (6, 2, 2) 的张量,其中每个条目是一列 A 乘以每行 B

例如,对于上面的 A 和 B,3D 张量应该具有以下矩阵:

[[[0, 0],
  [0, 0]],

 [[1, 1],
  [3, 3]],

 [[2, 2],
  [6, 6]],

 [[0, 0],
  [0, 0]],

 [[2, 2],
  [4, 4]],

 [[4, 4],
  [8, 8]]]

我知道如何通过 for 循环来完成它,但我想知道是否可以有一种有效的方法来保存它。

1 个答案:

答案 0 :(得分:1)

Pytorch 张量实现 numpy style broadcast semantics 可以解决这个问题。

从问题中不清楚您是要执行矩阵乘法还是元素乘法。在你展示的长度为 2 的情况下,两者是等价的,但对于高维肯定不是这样!谢天谢地,代码几乎相同,所以我只提供两个选项。

A = torch.FloatTensor([[1, 2], [3, 4]])
B = torch.FloatTensor([[0, 0], [1, 1], [2, 2]])

# matrix multiplication
C_mm = (A.T[:, None, :, None] @ B[None, :, None, :]).flatten(0, 1)

# element-wise multiplication
C_ew = (A.T[:, None, :, None] * B[None, :, None, :]).flatten(0, 1)

代码描述A.T 是形状 A。由于None(矩阵乘法)对张量的最后两个维度进行运算,并广播其他维度,因此结果是每列A.T[:, None, :, None]乘以(2, 1, 2, 1)的每一行的矩阵乘法。在元素方式的情况下,广播是在每个维度上执行的。结果是一个 B[None, :, None, :] 张量。要将其转换为 (1, 3, 1, 2) 张量,我们只需使用 Tensor.flatten 展平前两个维度。