我目前正在尝试将 3D CNN 应用于一组尺寸为 193 x 229 x 193 的图像,并希望通过每个卷积层保留相同的图像尺寸(类似于 tensorflow 的 padding=SAME
) .我知道可以按如下方式计算填充:
S=Stride
P=Padding
W=Width
K=Kernal size
P = ((S-1)*W-S+K)/2
第一层的填充为 1:
P = ((1-1)*193-1+3)/2
P= 1.0
虽然我也得到了每个后续层的 1.0
的结果。有人有什么建议吗?抱歉,这里是初学者!
可重现的例子:
import torch
import torch.nn as nn
x = torch.randn(1, 1, 193, 229, 193)
padding = ((1-1)*96-1+3)/2
print(padding)
x = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=3, padding=1)(x)
print("shape after conv1: " + str(x.shape))
x = nn.Conv3d(in_channels=8, out_channels=8, kernel_size=3,padding=1)(x)
x = nn.BatchNorm3d(8)(x)
print("shape after conv2 + batch norm: " + str(x.shape))
x = nn.ReLU()(x)
print("shape after reLU:" + str(x.shape))
x = nn.MaxPool3d(kernel_size=2, stride=2)(x)
print("shape after max pool" + str(x.shape))
x = nn.Conv3d(in_channels=8, out_channels=16, kernel_size=3,padding=1)(x)
print("shape after conv3: " + str(x.shape))
x = nn.Conv3d(in_channels=16, out_channels=16, kernel_size=3,padding=1)(x)
print("shape after conv4: " + str(x.shape))
当前输出:
shape after conv1: torch.Size([1, 8, 193, 229, 193])
shape after conv2 + batch norm: torch.Size([1, 8, 193, 229, 193])
shape after reLU:torch.Size([1, 8, 193, 229, 193])
shape after max pooltorch.Size([1, 8, 96, 114, 96])
shape after conv3: torch.Size([1, 16, 96, 114, 96])
shape after conv4: torch.Size([1, 16, 96, 114, 96])
所需的输出:
shape after conv1: torch.Size([1, 8, 193, 229, 193])
shape after conv2 + batch norm: torch.Size([1, 8, 193, 229, 193])
...
shape after conv3: torch.Size([1, 16, 193, 229, 193])
shape after conv4: torch.Size([1, 16, 193, 229, 193])
答案 0 :(得分:1)
TLDR;您的公式也适用于 nn.MaxPool3d
您正在使用内核大小为 2
(隐式 (2,2,2)
)的最大池层,步幅为 2
(隐式 (2,2,2)
)。这意味着对于每个 2x2x2
块,您只能获得一个值。换句话说 - 顾名思义:只有来自每个 2x2x2
块的最大值被合并到输出数组中。
这就是为什么你要从 (1, 8, 193, 229, 193)
到 (1, 8, 96, 114, 96)
(注意除以 2
)。
当然,如果您在 kernel_size=3
上设置 stride=1
和 nn.MaxPool3d
,您将保留块的形状。
让 #x
为输入形状,#w
为内核形状。如果我们希望输出具有相同的大小,则 #x = floor((#x + 2p - #w)/s + 1)
需要为真。那是 2p = s(#x - 1) - #x + #w = #x(s - 1) + #w - s
(你的公式)
由于 s = 2
和 #w = 2
,那么 2p = #x
这是不可能的。