I am trying to build a 3D U-Net for medical image segmentation that uses dilated convolutions, so that I get the receptive field I need without making the model too heavy. I am having a lot of trouble getting the encoder and decoder feature-map sizes to match. My input is [bs, 4, 96, 96, 64] and the output should be [bs, 5, 96, 96, 64], i.e. 5 class probabilities per voxel. As you can see, I define down blocks and up blocks, and the encoder part max-pools twice. The printed sizes and the error are shown after the code below.
The error occurs in the first up block: after interpolation the upsampled tensor has size 12 in dimension 2, while the corresponding skip feature map has size 13 there, so torch.cat fails. Can someone help me make the network symmetric, or at least get it to work?
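For reference, my understanding of where the mismatch starts (a minimal check, assuming stride 1 and a 3x3x3 kernel): a conv with dilation=1 and padding=1 keeps the spatial size, but with dilation=2 and padding=1 each conv removes 2 voxels per dimension, which is why the sizes drift and eventually become odd. The model code follows after this snippet.

import torch

def conv3d_out(size, kernel=3, padding=1, dilation=1, stride=1):
    # per-dimension output size of Conv3d, from PyTorch's documented formula
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

print(conv3d_out(96, padding=1, dilation=1))  # 96 -> size preserved
print(conv3d_out(96, padding=1, dilation=2))  # 94 -> shrinks by 2 per conv

# sanity check against an actual layer
x = torch.randn(1, 4, 96, 96, 64)
y = torch.nn.Conv3d(4, 8, 3, padding=1, dilation=2)(x)
print(y.shape)  # torch.Size([1, 8, 94, 94, 62])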
class UNet_down_block(torch.nn.Module):
    def __init__(self, input_channel, output_channel, down_size):
        super(UNet_down_block, self).__init__()
        self.conv1 = torch.nn.Conv3d(input_channel, output_channel, 3, padding=1, dilation=1)
        self.bn1 = torch.nn.BatchNorm3d(output_channel)
        self.conv2 = torch.nn.Conv3d(output_channel, output_channel, 3, padding=1, dilation=2)
        self.bn2 = torch.nn.BatchNorm3d(output_channel)
        self.conv3 = torch.nn.Conv3d(output_channel, output_channel, 3, padding=1, dilation=2)
        self.bn3 = torch.nn.BatchNorm3d(output_channel)
        self.max_pool = torch.nn.MaxPool3d(2, 2)
        self.relu = torch.nn.ELU()
        self.down_size = down_size

    def forward(self, x):
        if self.down_size:
            x = self.max_pool(x)
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.bn3(self.conv3(x)))
        return x
class UNet_up_block(torch.nn.Module):
    def __init__(self, prev_channel, input_channel, output_channel):
        super(UNet_up_block, self).__init__()
        # self.up_sampling = torch.nn.functional.interpolate(scale_factor=2, mode='trilinear')
        self.conv1 = torch.nn.Conv3d(prev_channel + input_channel, output_channel, 3, padding=1)
        self.bn1 = torch.nn.BatchNorm3d(output_channel)
        self.conv2 = torch.nn.Conv3d(output_channel, output_channel, 3, padding=1)
        self.bn2 = torch.nn.BatchNorm3d(output_channel)
        self.conv3 = torch.nn.Conv3d(output_channel, output_channel, 3, padding=1)
        self.bn3 = torch.nn.BatchNorm3d(output_channel)
        self.relu = torch.nn.ELU()

    def forward(self, prev_feature_map, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2, mode='trilinear')
        x = torch.cat((x, prev_feature_map), dim=1)
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.bn3(self.conv3(x)))
        return x
class UNet(torch.nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        self.down_block1 = UNet_down_block(4, 24, False)
        self.down_block2 = UNet_down_block(24, 72, True)
        self.down_block3 = UNet_down_block(72, 148, True)
        self.down_block4 = UNet_down_block(148, 224, False)
        self.max_pool = torch.nn.MaxPool3d(2, 2)
        self.mid_conv1 = torch.nn.Conv3d(224, 224, 3, padding=1)
        self.bn1 = torch.nn.BatchNorm3d(224)
        self.mid_conv2 = torch.nn.Conv3d(224, 224, 3, padding=1)
        self.bn2 = torch.nn.BatchNorm3d(224)
        self.mid_conv3 = torch.nn.Conv3d(224, 224, 3, padding=1)
        self.bn3 = torch.nn.BatchNorm3d(224)
        self.up_block1 = UNet_up_block(224, 224, 148)
        self.up_block2 = UNet_up_block(148, 148, 72)
        self.up_block3 = UNet_up_block(72, 72, 24)
        self.up_block4 = UNet_up_block(24, 24, 8)
        self.last_conv1 = torch.nn.Conv3d(8, 4, 3, padding=1)
        self.last_bn = torch.nn.BatchNorm3d(4)
        self.last_conv2 = torch.nn.Conv3d(4, 1, 1, padding=0)
        self.last_conv3 = torch.nn.Conv3d(1, 1, 1, padding=0)
        self.relu = torch.nn.ELU()
        self.conv1f = torch.nn.Conv2d(1, 5, 3, padding=1)
        self.conv2f = torch.nn.Conv2d(5, 5, 3, padding=1)
        self.conv3f = torch.nn.Conv2d(5, 5, 3, padding=1)

    def forward(self, x):
        print('input unet', x.size())
        self.x1 = self.down_block1(x)
        print("Block 1 shape:", self.x1.size())
        self.x2 = self.down_block2(self.x1)
        if self.x2.size()[2] == 49:  # ad-hoc crop when the pooled size comes out odd
            self.x2 = self.x2[:, :, 1:, 1:, :]
        print("Block 2 shape:", self.x2.size())
        self.x3 = self.down_block3(self.x2)
        print("Block 3 shape:", self.x3.size())
        self.x4 = self.down_block4(self.x3)
        print("Block 4 shape:", self.x4.size())
        self.xmid = self.max_pool(self.x4)
        self.xmid = self.relu(self.bn1(self.mid_conv1(self.xmid)))
        self.xmid = self.relu(self.bn2(self.mid_conv2(self.xmid)))
        self.xmid = self.relu(self.bn3(self.mid_conv3(self.xmid)))
        print("Block Mid shape:", self.xmid.size())
        x = self.up_block1(self.x4, self.xmid)
        # print("BlockU 1 shape:", x.size())
        x = self.up_block2(self.x3, x)
        print("BlockU 2 shape:", x.size())
        x = self.up_block3(self.x2, x)
        print("BlockU 3 shape:", x.size())
        if self.x1.size()[2] == 98:  # another ad-hoc crop for odd sizes
            self.x1 = self.x1[:, :, 1:-1, 1:-1, :]
            # print('chan98', self.x1.size())
        x = self.up_block4(self.x1, x)
        print("BlockU 4 shape:", x.size())
        x = self.relu(self.last_bn(self.last_conv1(x)))
        x = self.last_conv2(x)  # of size [batch_size, 1, h, w, depth] or [bs, modalities(1), 96, 96, 64]
        x = x.view(batch_size, 1, -1, 64)
        # x = x.squeeze(1)
        # print('input convf', x.size())
        conv = self.relu(self.conv1f(x))
        conv = self.relu(self.conv2f(conv))
        conv = self.conv3f(conv)
        try:
            conv = conv.view(batch_size, 5, 96, 96, 64)
        except:
            conv = conv.view(batch_size_val, 5, 96, 96, 64)
        # print('unet output', conv.size())
        return conv
Here is the output:
input unet torch.Size([1, 4, 96, 96, 64])
Block 1 shape: torch.Size([1, 24, 92, 92, 60])
Block 2 shape: torch.Size([1, 72, 42, 42, 26])
Block 3 shape: torch.Size([1, 148, 17, 17, 9])
Block 4 shape: torch.Size([1, 224, 13, 13, 5])
Block Mid shape: torch.Size([1, 224, 6, 6, 2])
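If I apply the size formula from above block by block (this is my own arithmetic, so please correct me if I am misreading my network), it reproduces the sizes just printed and shows why the first up block fails:

# pooling halves with floor division; each dilation=2 conv removes 2 per dim
def down_block_size(sizes, pool):
    if pool:
        sizes = [s // 2 for s in sizes]
    return [s - 4 for s in sizes]  # two dilated convs, -2 each

s = [96, 96, 64]
for pool in (False, True, True, False):  # down_block1..4
    s = down_block_size(s, pool)
    print(s)  # [92, 92, 60], [42, 42, 26], [17, 17, 9], [13, 13, 5]

mid = [v // 2 for v in s]          # extra max_pool before the mid convs
print(mid)                         # [6, 6, 2]
print([2 * v for v in mid])        # [12, 12, 4] -> cannot cat with [13, 13, 5]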
Error:
x = self.up_block1(self.x4, self.xmid)
111 # print("BlockU 1 shape:",x.size())
112 x = self.up_block2(self.x3, x)
/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
475 result = self._slow_forward(*input, **kwargs)
476 else:
--> 477 result = self.forward(*input, **kwargs)
478 for hook in self._forward_hooks.values():
479 hook_result = hook(self, input, result)
<ipython-input-5-cbcdda025480> in forward(self, prev_feature_map, x)
39 x = torch.nn.functional.interpolate(x,scale_factor=2, mode='trilinear')
40
---> 41 x = torch.cat((x, prev_feature_map), dim=1)
42 x = self.relu(self.bn1(self.conv1(x)))
43 x = self.relu(self.bn2(self.conv2(x)))
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 1. Got 12 and 13 in dimension 2 at /opt/conda/conda-bld/pytorch_1535491974311/work/aten/src/TH/generic/THTensorMath.cpp:3616
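In case it is useful for an answer, these are the two adjustments I have been considering (a rough sketch, not tested end to end): setting padding equal to the dilation so a 3x3x3 conv keeps its spatial size, and resizing to the skip tensor's exact shape before the concat instead of assuming a factor of exactly 2. The helper name upsample_to is just something I made up for illustration; the shapes are taken from the printout above.

import torch
import torch.nn.functional as F

# (1) with a 3x3x3 kernel, padding = dilation keeps the spatial size, so the
#     encoder would only change resolution at the pooling steps
conv = torch.nn.Conv3d(24, 24, 3, padding=2, dilation=2)
print(conv(torch.randn(1, 24, 96, 96, 64)).shape)  # [1, 24, 96, 96, 64]

# (2) in the up block, resize to the skip connection's exact spatial size
#     instead of a fixed scale_factor=2, so torch.cat can never mismatch
def upsample_to(x, prev_feature_map):
    return F.interpolate(x, size=prev_feature_map.shape[2:],
                         mode='trilinear', align_corners=False)

xmid = torch.randn(1, 224, 6, 6, 2)
x4 = torch.randn(1, 224, 13, 13, 5)
x = upsample_to(xmid, x4)
print(torch.cat((x, x4), dim=1).shape)  # [1, 448, 13, 13, 5]

Does that sound like a reasonable direction, or is there a cleaner way to keep the two paths symmetric?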