用于计算卷积pytorch(googlenet)中的填充的公式

时间:2020-11-03 10:52:10

标签: tensorflow deep-learning computer-vision pytorch

我正在pytorch中从头开始实现googlenet(较小版本)。架构如下:

enter image description here

对于下采样模块,我具有以下代码:

   class DownSampleModule(nn.Module):
   def __init__(self, in_channel, ch3, w):
       super(DownSampleModule, self).__init__()
       kernel_size = 3
       padding = (kernel_size-1)/2

       self.branch1 = nn.Sequential(
           ConvBlock(in_channel, ch3, kernel_size = 3,stride=2, padding=int(padding))
       )
       self.branch2 = nn.Sequential(
           nn.MaxPool2d(3, stride=2, padding=0, ceil_mode=True)
       )
   def forward(self, x):
       branch1 = self.branch1(x)
       branch2 = self.branch2(x)
      
       return torch.cat([padded_tensor, branch2], 1)

ConvBlock来自此模块

class ConvBlock(nn.Module):
   def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
       super(ConvBlock, self).__init__()
       #padding = (kernel_size -1 )/2
       #print(padding)
       self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
       self.bn = nn.BatchNorm2d(out_channels)
       self.act = nn.ReLU()
       
   def forward(self, x):
       x = self.conv(x)
       x = self.bn(x)
       x = self.act(x)
       return x

基本上,我们正在创建两个分支:卷积模块和最大池。然后,将这两个分支的输出串联在通道维度上。

但是,我遇到以下问题:

  • 首先,我们叫self.pool1 = DownSampleModule(in_channel=80, ch3 = 80, w=30)。两个分支的尺寸相似。这些是:
Downsample Convolution:torch.Size([1, 80, 15, 15])
Maxpool Convolution:torch.Size([1, 80, 15, 15])
  • 但是,当我们致电self.pool2 = DownSampleModule(in_channel = 144, ch3 = 96, w=15)时。尺寸各不相同,因此无法将其串联起来。
Downsample Convolution:torch.Size([1, 96, 8, 8])
Maxpool Convolution:torch.Size([1, 144, 7, 7])

有人知道计算正确填充的公式吗?谢谢。

在Keras中,您可以只设置padding =“ same”或“ valid”,但是pytorch不支持它。

1 个答案:

答案 0 :(得分:1)

您的maxpoolconv分支具有相同的输入,如果为内核大小,步幅和填充赋予相同的参数,它们将产生形状相同的输出。因此,仅将padding = 0替换为padding = int(padding)就足以使两个分支都兼容。

ceil_mode也应设置为False。当结果维不是整数时,conv2d的舍入行为将使用floor,因此您希望maxpool也这样做。

顺便说一句,您可以删除自己的nn.Sequential。您的层“序列”仅由一层组成,所以...不是真正的序列:)