在 Pytorch 中计算 3D CNN 的填充

时间:2021-02-07 13:09:24

标签: neural-network pytorch conv-neural-network padding

我目前正在尝试将 3D CNN 应用于一组尺寸为 193 x 229 x 193 的图像,并希望通过每个卷积层保留相同的图像尺寸(类似于 tensorflow 的 padding=SAME) .我知道可以按如下方式计算填充:

S=Stride
P=Padding
W=Width
K=Kernal size

P = ((S-1)*W-S+K)/2

第一层的填充为 1:

P = ((1-1)*193-1+3)/2
P= 1.0

虽然我也得到了每个后续层的 1.0 的结果。有人有什么建议吗?抱歉,这里是初学者!

可重现的例子:

import torch
import torch.nn as nn

x = torch.randn(1, 1, 193, 229, 193)

padding = ((1-1)*96-1+3)/2
print(padding)

x = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=3, padding=1)(x)
print("shape after conv1: " + str(x.shape))
x = nn.Conv3d(in_channels=8, out_channels=8, kernel_size=3,padding=1)(x)
x = nn.BatchNorm3d(8)(x) 
print("shape after conv2 + batch norm: " + str(x.shape))
x = nn.ReLU()(x)
print("shape after reLU:" + str(x.shape))
x = nn.MaxPool3d(kernel_size=2, stride=2)(x)
print("shape after max pool" + str(x.shape))
x = nn.Conv3d(in_channels=8, out_channels=16, kernel_size=3,padding=1)(x)
print("shape after conv3: " + str(x.shape))
x = nn.Conv3d(in_channels=16, out_channels=16, kernel_size=3,padding=1)(x)
print("shape after conv4: " + str(x.shape))

当前输出:

shape after conv1: torch.Size([1, 8, 193, 229, 193])
shape after conv2 + batch norm: torch.Size([1, 8, 193, 229, 193])
shape after reLU:torch.Size([1, 8, 193, 229, 193])
shape after max pooltorch.Size([1, 8, 96, 114, 96])
shape after conv3: torch.Size([1, 16, 96, 114, 96])
shape after conv4: torch.Size([1, 16, 96, 114, 96])

所需的输出:

shape after conv1: torch.Size([1, 8, 193, 229, 193])
shape after conv2 + batch norm: torch.Size([1, 8, 193, 229, 193])
...
shape after conv3: torch.Size([1, 16, 193, 229, 193])
shape after conv4: torch.Size([1, 16, 193, 229, 193])

1 个答案:

答案 0 :(得分:1)

TLDR;您的公式也适用于 nn.MaxPool3d

您正在使用内核大小为 2(隐式 (2,2,2))的最大池层,步幅为 2(隐式 (2,2,2))。这意味着对于每个 2x2x2 块,您只能获得一个值。换句话说 - 顾名思义:只有来自每个 2x2x2 块的最大值被合并到输出数组中。

这就是为什么你要从 (1, 8, 193, 229, 193)(1, 8, 96, 114, 96)(注意除以 2)。

当然,如果您在 kernel_size=3 上设置 stride=1nn.MaxPool3d,您将保留块的形状。


#x 为输入形状,#w 为内核形状。如果我们希望输出具有相同的大小,则 #x = floor((#x + 2p - #w)/s + 1) 需要为真。那是 2p = s(#x - 1) - #x + #w = #x(s - 1) + #w - s(你的公式)

由于 s = 2#w = 2,那么 2p = #x 这是不可能的。