Pytorch中的批量矩阵乘法-与输出尺寸的处理混淆

时间:2019-06-11 12:37:27

标签: python vectorization pytorch batch-processing matrix-multiplication

我有两个数组:

A
B

数组A包含一批RGB图像,形状为:

[batch, Width, Height, 3]

而数组B包含对图像进行“类似变换”操作所需的系数,形状为:

[batch, 4, 4, 3]

简单地说,对单个图像的运算是一个乘法,输出一个环境图(normalMap * Coefficients)。

我想要的输出应该保持形状:

[batch, Width, Height, 3]

我尝试使用torch.bmm,但失败了。这有可能吗?

1 个答案:

答案 0 :(得分:0)

我认为您需要计算PyTorch可与

一起使用
BxCxHxW : number of mini-batches, channels, height, width

格式,并且也使用matmul,因为bmm可用于3D张量。

我知道您可能会在网上找到它,但是无论如何:

batch1 = torch.randn(10, 3, 20, 10)
batch2 = torch.randn(10, 3, 10, 30)
res = torch.matmul(batch1, batch2)
res.size() # torch.Size([10, 3, 20, 30])