在keras
中,存在Cropping3D
层,用于在神经网络内部对3D体的张量进行中心裁剪。但是,尽管在pytorch中有torchvision.transforms.CenterCrop(size)
用于2D图像,我却找不到类似的东西。
如何在网络内部进行裁剪?否则,我需要在预处理中做这件事,由于特定原因,这是我要做的最后一件事。
我是否需要编写一个自定义层,例如沿每个轴切片输入张量?希望对此有所启发
答案 0 :(得分:0)
在PyTorch中,您不一定需要为所有内容编写图层,通常您可以在前进过程中直接进行所需的操作。在需要计算梯度的火炬张量上操作时,需要记住的基本规则是
torch.sum
而不是转换为numpy并使用numpy.sum
)。x = x + ...
而不是x += ...
)。也就是说,您可以使用切片,也许看起来像这样
def forward(self, x):
...
x = self.conv3(x)
x = x[:, :, 5:20, 5:20] # crop out part of the feature map
x = self.relu3(x)
...