如何使用pytorch在网络模型内进行3D体积的中心裁切

时间:2019-08-15 22:55:31

标签: python 3d conv-neural-network pytorch crop

keras中,存在Cropping3D层,用于在神经网络内部对3D体的张量进行中心裁剪。但是,尽管在pytorch中有torchvision.transforms.CenterCrop(size)用于2D图像,我却找不到类似的东西。

如何在网络内部进行裁剪?否则,我需要在预处理中做这件事,由于特定原因,这是我要做的最后一件事。

我是否需要编写一个自定义层,例如沿每个轴切片输入张量?希望对此有所启发

1 个答案:

答案 0 :(得分:0)

在PyTorch中,您不一定需要为所有内容编写图层,通常您可以在前进过程中直接进行所需的操作。在需要计算梯度的火炬张量上操作时,需要记住的基本规则是

  1. 请勿将割炬张量转换为其他类型进行计算(例如,使用torch.sum而不是转换为numpy并使用numpy.sum)。
  2. 请勿执行就地操作(例如,更改张量的一个元素或使用就地运算符,因此请使用x = x + ...而不是x += ...)。

也就是说,您可以使用切片,也许看起来像这样

def forward(self, x):
    ...
    x = self.conv3(x)
    x = x[:, :, 5:20, 5:20]    # crop out part of the feature map
    x = self.relu3(x)
    ...