pytorch中的conv2d函数

时间:2019-05-05 18:13:02

标签: python pytorch convolution

我正在尝试使用Pytorch的函数torch.conv2d,但无法获得我理解的结果...

这是一个简单的示例,其中内核(filt)与输入(im)的大小相同,以解释我要查找的内容。

import pytorch

filt = torch.rand(3, 3)
im = torch.rand(3, 3)

我想计算一个没有填充的简单卷积,所以结果应该是标量(即1x1张量)。

我尝试过conv2d

# I have to convert image and kernel to 4 dimensions tensors to use conv2d
im_torch = im.reshape((im_height, filt_height, 1, 1))
filt_torch = filt.reshape((filt_height, im_height, 1, 1))
out = torch.nn.functional.conv2d(im_torch, filt_torch, stride=1, padding=0)
print(out)

但是结果不是我所期望的:

tensor([[[[0.6067]], [[0.3564]], [[0.5397]]],
    [[[0.2557]], [[0.0493]], [[0.2562]]],
    [[[0.6067]], [[0.3564]], [[0.5397]]]])

要想了解我想要的东西,我想重现粗俗的convolve2d行为:

import scipy.signal
out_scipy = scipy.signal.convolve2d(im.detach().numpy(), filt.detach().numpy(), 'valid')
print(out_scipy)

打印:

array([[1.195723]], dtype=float32)

2 个答案:

答案 0 :(得分:2)

输入和过滤器的张量形状应为:

(batch, dim_ch, width, height)

而非:

(width, height, 1, 1)

例如

import torch
import torch.nn.functional as F
x = torch.randn(1,1,4,4);
y = torch.randn(1,1,4,4);
z = F.conv2d(x,y);

z的输出形状:

torch.Size([1,1,1,1])

答案 1 :(得分:0)

好吧,我没有找到问题的确切答案(即如何使用conv2d),但是我找到了另一种解决方法。

首先,我了解到我正在寻找的称为 valid 互相关,它实际上是[Conv2d][1]类实现的操作。

因此,我的解决方案使用Conv2d类而不是conv2d函数。

import pytorch

img = torch.rand(3, 3)

model = torch.nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(3, 3), stride=1, padding=0, bias=False)

res = conv_mdl(img)
print(res.shape)

打印我想要的标量:

torch.Size([1, 1, 1, 1])

PS:我还检查了结果是否正确,而不仅仅是尺寸。