PyTorch 4D和2D矩阵的广播乘法?

时间:2020-04-09 03:40:03

标签: python pytorch broadcast torch

如何广播以将这两个矩阵相乘?

x: torch.Size([10, 120, 180, 30]) # (N, H, W, C)
W: torch.Size([64, 30]) # (Y, C)

输出应为:

(10, 120, 180, 64) == (N, H, W, Y)

1 个答案:

答案 0 :(得分:1)

我假设x是带有批次的示例,而w矩阵是相应的权重。在这种情况下,您可以简单地执行以下操作:

out = x @ w.T

这是张量乘法,而不是元素方式的。您无法进行逐元素乘法来获得这种形状,并且此操作没有意义。您所能做的就是以某种方式unsqueeze对两个矩阵进行广播,并对由于某些原因而不需要的维进行一些运算:

x : torch.Size([10, 120, 180, 30, 1])
W: torch.Size([1, 1, 1, 30, 64]) # transposition would be needed as well

经过unsqueezing之后,您可以沿着第三x*w进行summeandim以获得所需的形状。

为清楚起见,这两种方式不相同