Dilated UNet size

Date: 2018-10-23 06:08:00

Tags: deep-learning conv-neural-network image-segmentation unity3d-unet

I am trying to create a 3D UNet for medical image segmentation with dilated layers, so that I get the receptive field size I need without making the model too heavy. I am having a lot of trouble getting the encoder and decoder sizes to match up. My input is [bs, 4, 96, 96, 64] and the output is [bs, 5, 96, 96, 64], i.e. 5 possible class probabilities per pixel. As you can see, I define a down block and an up block, and I max-pool twice in the encoder part. I show the sizes and the error after the code below:

The error occurs in the first up block: the interpolation there produces size 12 in the third dimension, while the corresponding skip block has size 13 in that dimension. Can someone help me make this symmetric, or at least workable?
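
For reference, a minimal standalone check (separate from the model code below) of the Conv3d output-size arithmetic: with kernel size 3 and dilation=2, padding=1 trims each spatial dimension by 2 per convolution, while padding=2 would keep it unchanged.

import torch

# out = floor((in + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1
x = torch.randn(1, 4, 96, 96, 64)
conv_trim = torch.nn.Conv3d(4, 4, 3, padding=1, dilation=2)  # as used in the blocks below
conv_same = torch.nn.Conv3d(4, 4, 3, padding=2, dilation=2)  # size-preserving padding
print(conv_trim(x).shape)  # torch.Size([1, 4, 94, 94, 62])
print(conv_same(x).shape)  # torch.Size([1, 4, 96, 96, 64])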

import torch

class UNet_down_block(torch.nn.Module):
    def __init__(self, input_channel, output_channel, down_size):
        super(UNet_down_block, self).__init__()
        self.conv1 = torch.nn.Conv3d(input_channel, output_channel, 3, padding=1,dilation=1)
        self.bn1 = torch.nn.BatchNorm3d(output_channel)
        self.conv2 = torch.nn.Conv3d(output_channel, output_channel, 3, padding=1,dilation=2)
        self.bn2 = torch.nn.BatchNorm3d(output_channel)
        self.conv3 = torch.nn.Conv3d(output_channel, output_channel, 3, padding=1,dilation=2)
        self.bn3 = torch.nn.BatchNorm3d(output_channel)
        self.max_pool = torch.nn.MaxPool3d(2, 2)
        self.relu = torch.nn.ELU()
        self.down_size = down_size

    def forward(self, x):
        if self.down_size:
            x = self.max_pool(x)
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.bn3(self.conv3(x)))
        return x

class UNet_up_block(torch.nn.Module):
    def __init__(self, prev_channel, input_channel, output_channel):
        super(UNet_up_block, self).__init__()
#         self.up_sampling = torch.nn.functional.interpolate(scale_factor=2, mode='trilinear')
        self.conv1 = torch.nn.Conv3d(prev_channel + input_channel, output_channel, 3, padding=1)
        self.bn1 = torch.nn.BatchNorm3d(output_channel)
        self.conv2 = torch.nn.Conv3d(output_channel, output_channel, 3, padding=1)
        self.bn2 = torch.nn.BatchNorm3d(output_channel)
        self.conv3 = torch.nn.Conv3d(output_channel, output_channel, 3, padding=1)
        self.bn3 = torch.nn.BatchNorm3d(output_channel)
        self.relu = torch.nn.ELU()

    def forward(self, prev_feature_map, x):

        x = torch.nn.functional.interpolate(x,scale_factor=2, mode='trilinear')

        x = torch.cat((x, prev_feature_map), dim=1)
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.bn3(self.conv3(x)))
        return x


class UNet(torch.nn.Module):
    def __init__(self):
        super(UNet, self).__init__()

        self.down_block1 = UNet_down_block(4, 24, False)
        self.down_block2 = UNet_down_block(24, 72, True)
        self.down_block3 = UNet_down_block(72, 148, True)
        self.down_block4 = UNet_down_block(148, 224, False)
        self.max_pool = torch.nn.MaxPool3d(2, 2)



        self.mid_conv1 = torch.nn.Conv3d(224, 224, 3, padding=1)
        self.bn1 = torch.nn.BatchNorm3d(224)
        self.mid_conv2 = torch.nn.Conv3d(224, 224, 3, padding=1)
        self.bn2 = torch.nn.BatchNorm3d(224)
        self.mid_conv3 = torch.nn.Conv3d(224, 224, 3, padding=1)
        self.bn3 = torch.nn.BatchNorm3d(224)


        self.up_block1 = UNet_up_block(224, 224, 148)
        self.up_block2 = UNet_up_block(148, 148, 72)
        self.up_block3 = UNet_up_block(72, 72, 24)
        self.up_block4 = UNet_up_block(24, 24, 8)

        self.last_conv1 = torch.nn.Conv3d(8, 4, 3, padding=1)
        self.last_bn = torch.nn.BatchNorm3d(4)
        self.last_conv2 = torch.nn.Conv3d(4, 1, 1, padding=0)
        self.relu = torch.nn.ELU()
        self.last_conv3 = torch.nn.Conv3d(1, 1, 1, padding=0)
        self.relu = torch.nn.ELU()

        self.conv1f=torch.nn.Conv2d(1, 5, 3,padding=1)
        self.conv2f=torch.nn.Conv2d(5, 5, 3,padding=1)
        self.conv3f=torch.nn.Conv2d(5, 5, 3,padding=1)

    def forward(self, x):
        print('input unet',x.size())
        self.x1 = self.down_block1(x)
        print("Block 1 shape:",self.x1.size())
        self.x2 = self.down_block2(self.x1)
        if self.x2.size()[2]==49:  # hack: crop when dimension 2 comes out as 49
            self.x2=self.x2[:,:,1:,1:,:]


        print("Block 2 shape:",self.x2.size())
        self.x3 = self.down_block3(self.x2)
        print("Block 3 shape:",self.x3.size())


        self.x4 = self.down_block4(self.x3)
        print("Block 4 shape:",self.x4.size())


        self.xmid=self.max_pool(self.x4)
        self.xmid = self.relu(self.bn1(self.mid_conv1(self.xmid)))
        self.xmid = self.relu(self.bn2(self.mid_conv2(self.xmid)))
        self.xmid = self.relu(self.bn3(self.mid_conv3(self.xmid)))
        print("Block Mid shape:",self.xmid.size())



        x = self.up_block1(self.x4, self.xmid)
#         print("BlockU 1 shape:",x.size())
        x = self.up_block2(self.x3, x)
        print("BlockU 2 shape:",x.size())

        x = self.up_block3(self.x2, x)
        print("BlockU 3 shape:",x.size())

        if self.x1.size()[2]==98:  # hack: crop when dimension 2 comes out as 98
            self.x1=self.x1[:,:,1:-1,1:-1,:]
#             print('chan98',self.x1.size())

        x = self.up_block4(self.x1, x)
        print("BlockU 4 shape:",x.size())


        x = self.relu(self.last_bn(self.last_conv1(x)))
        x = self.last_conv2(x)  # of size [batch_size,1,h,w,depth] or [bs, modalities(1) ,96 ,96 , 64]

        x=x.view(batch_size,1,-1,64)
#         x=x.squeeze(1)  
#         print('input convf',x.size())
        conv=self.relu(self.conv1f(x))
        conv=self.relu(self.conv2f(conv))
        conv=self.conv3f(conv)


        try:
            conv=conv.view(batch_size,5,96,96,64)
        except:
            conv=conv.view(batch_size_val,5,96,96,64)
#         print('unet output',conv.size())

        return(conv)

Here is the output:

input unet torch.Size([1, 4, 96, 96, 64])
Block 1 shape: torch.Size([1, 24, 92, 92, 60])
Block 2 shape: torch.Size([1, 72, 42, 42, 26])
Block 3 shape: torch.Size([1, 148, 17, 17, 9])
Block 4 shape: torch.Size([1, 224, 13, 13, 5])
Block Mid shape: torch.Size([1, 224, 6, 6, 2]) 
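
Tracing the last step of that output: MaxPool3d(2, 2) floors the odd Block 4 size [13, 13, 5] down to [6, 6, 2], and upsampling by a fixed scale_factor=2 can then only give [12, 12, 4], which is what the concatenation complains about below. A minimal reproduction of just that step:

import torch

x4 = torch.randn(1, 224, 13, 13, 5)                  # Block 4 output size
xmid = torch.nn.MaxPool3d(2, 2)(x4)                  # floors to [1, 224, 6, 6, 2]
up = torch.nn.functional.interpolate(xmid, scale_factor=2, mode='trilinear')
print(up.shape)                                      # torch.Size([1, 224, 12, 12, 4]), not [1, 224, 13, 13, 5]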

Error:
--> 110         x = self.up_block1(self.x4, self.xmid)
    111 #         print("BlockU 1 shape:",x.size())
    112         x = self.up_block2(self.x3, x)

/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    475             result = self._slow_forward(*input, **kwargs)
    476         else:
--> 477             result = self.forward(*input, **kwargs)
    478         for hook in self._forward_hooks.values():
    479             hook_result = hook(self, input, result)

<ipython-input-5-cbcdda025480> in forward(self, prev_feature_map, x)
     39         x = torch.nn.functional.interpolate(x,scale_factor=2, mode='trilinear')
     40 
---> 41         x = torch.cat((x, prev_feature_map), dim=1)
     42         x = self.relu(self.bn1(self.conv1(x)))
     43         x = self.relu(self.bn2(self.conv2(x)))

RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 1. Got 12 and 13 in dimension 2 at /opt/conda/conda-bld/pytorch_1535491974311/work/aten/src/TH/generic/THTensorMath.cpp:3616
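
One possible workaround (untested, and align_corners=False is only added to make the call explicit) would be to upsample to the exact spatial size of the skip tensor instead of using a fixed scale factor in UNet_up_block.forward, for example:

    def forward(self, prev_feature_map, x):
        # upsample to the skip tensor's exact spatial size so torch.cat always matches
        x = torch.nn.functional.interpolate(
            x, size=prev_feature_map.shape[2:], mode='trilinear', align_corners=False)
        x = torch.cat((x, prev_feature_map), dim=1)
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.bn3(self.conv3(x)))
        return x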

0 Answers