PyTorch运行时错误:参数无效0:张量的大小必须匹配,但维度1除外

时间:2019-01-29 09:29:00

标签: deep-learning pytorch

我有一个PyTorch模型,我正在尝试通过执行前向传递来对其进行测试。这是代码:

class ResBlock(nn.Module):
    def __init__(self, inplanes, planes, stride=1):
        super(ResBlock, self).__init__()
        self.conv1x1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False)
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        #batch normalization
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.stride = stride

    def forward(self, x):
        residual = self.conv1x1(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        #adding the skip connection
        out += residual
        out = self.relu(out)

        return out

class ResUnet (nn.Module):

    def __init__(self, in_shape,  num_classes):
        super(ResUnet, self).__init__()
        in_channels, height, width = in_shape
        #
        #self.L1 = IncResBlock(in_channels,64)
        self.e1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=2,padding=1),
            ResBlock(64,64))


        self.e2 = nn.Sequential(
            nn.LeakyReLU(0.2,),
            nn.Conv2d(64, 128, kernel_size=4, stride=2,padding=1),
            nn.BatchNorm2d(128),
            ResBlock(128,128))
        #
        self.e2add = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, stride=1,padding=1),
            nn.BatchNorm2d(128))
        #
        ##
        self.e3 = nn.Sequential(
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=1,padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2,),
            nn.Conv2d(128,256, kernel_size=4, stride=2,padding=1),
            nn.BatchNorm2d(256),
            ResBlock(256,256))

        self.e4 = nn.Sequential(
            nn.LeakyReLU(0.2,),
            nn.Conv2d(256,512, kernel_size=4, stride=2,padding=1),
            nn.BatchNorm2d(512),
            ResBlock(512,512))
        #
        self.e4add = nn.Sequential(
            nn.Conv2d(512,512, kernel_size=3, stride=1,padding=1),
            nn.BatchNorm2d(512)) 
        #
        self.e5 = nn.Sequential(
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(512,512, kernel_size=3, stride=1,padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2,),
            nn.Conv2d(512,512, kernel_size=4, stride=2,padding=1),
            nn.BatchNorm2d(512),
            ResBlock(512,512))
        #
        #
        self.e6 = nn.Sequential(
            nn.LeakyReLU(0.2,),
            nn.Conv2d(512,512, kernel_size=4, stride=2,padding=1),
            nn.BatchNorm2d(512), 
            ResBlock(512,512))
        #
        self.e6add = nn.Sequential(
            nn.Conv2d(512,512, kernel_size=3, stride=1,padding=1),
            nn.BatchNorm2d(512)) 
        #
        self.e7 = nn.Sequential(
            nn.LeakyReLU(0.2,inplace=True),
            nn.Conv2d(512,512, kernel_size=3, stride=1,padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2,),
            nn.Conv2d(512,512, kernel_size=4, stride=2,padding=1),
            nn.BatchNorm2d(512),
            ResBlock(512,512))
        #
        self.e8 = nn.Sequential(
            nn.LeakyReLU(0.2,),
            nn.Conv2d(512,512, kernel_size=4, stride=2,padding=1))
            #nn.BatchNorm2d(512))

        self.d1 = nn.Sequential(
            nn.ReLU(),
            nn.ConvTranspose2d(512, 512, kernel_size=4, stride=2,padding=1),
            nn.BatchNorm2d(512),
            nn.Dropout(p=0.5),
            ResBlock(512,512))
        #
        self.d2 = nn.Sequential(
            nn.ReLU(),
            nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2,padding=1),
            nn.BatchNorm2d(512),
            nn.Dropout(p=0.5),
            ResBlock(512,512))
        #
        self.d3 = nn.Sequential(
            nn.ReLU(),
            nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2,padding=1),
            nn.BatchNorm2d(512),
            nn.Dropout(p=0.5),
            ResBlock(512,512))
        #
        self.d4 = nn.Sequential(
            nn.ReLU(),
            nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2,padding=1),
            nn.BatchNorm2d(512),
            ResBlock(512,512))

        #
        self.d5 = nn.Sequential(
            nn.ReLU(),
            nn.ConvTranspose2d(1024, 256, kernel_size=4, stride=2,padding=1),
            nn.BatchNorm2d(256),
            ResBlock(256,256))
        #
        self.d6 = nn.Sequential(
            nn.ReLU(),
            nn.ConvTranspose2d(512, 128, kernel_size=4, stride=2,padding=1),
            nn.BatchNorm2d(128),
            ResBlock(128,128))
        #
        self.d7 = nn.Sequential(
            nn.ReLU(),
            nn.ConvTranspose2d(256, 64, kernel_size=4, stride=2,padding=1),
            nn.BatchNorm2d(64),
            ResBlock(64,64))
        #
        self.d8 = nn.Sequential(
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2,padding=1))
            #nn.BatchNorm2d(64),
            #nn.ReLU())

        self.out_l = nn.Sequential(
            nn.Conv2d(64,num_classes,kernel_size=1,stride=1))
            #nn.ReLU())

    def forward(self, x):

        #Image Encoder

        #### Encoder #####

        en1 = self.e1(x)

        en2 = self.e2(en1)
        en2add = self.e2add(en2)

        en3 = self.e3(en2add)

        en4 = self.e4(en3)
        en4add = self.e4add(en4)

        en5 = self.e5(en4add)

        en6 = self.e6(en5)
        en6add = self.e6add(en6)

        en7 = self.e7(en6add)

        en8 = self.e8(en7)

        #### Decoder ####
        de1_ = self.d1(en8)
        de1 = torch.cat([en7,de1_],1)

        de2_ = self.d2(de1)
        de2 = torch.cat([en6add,de2_],1)


        de3_ = self.d3(de2)
        de3 = torch.cat([en5,de3_],1)


        de4_ = self.d4(de3)
        de4 = torch.cat([en4add,de4_],1)


        de5_ = self.d5(de4)
        de5 = torch.cat([en3,de5_],1)

        de6_ = self.d6(de5)
        de6 = torch.cat([en2add,de6_],1)

        de7_ = self.d7(de6)
        de7 = torch.cat([en1,de7_],1)
        de8 = self.d8(de7)

        out_l_mask = self.out_l(de8)

        return out_l_mask  

这是我尝试测试的方式:

modl = ResUnet((1,512,512), 1)
x = torch.rand(1, 1, 512, 512)
modl(x)

这很好用,对于64倍数的任何大小也一样。

如果我尝试:

modl = ResUnet((1,320,320), 1)
x = torch.rand(1, 1, 320, 320)
modl(x)

它抛出一个错误

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-46-4ddc821c365b> in <module>
----> 1 modl(x)

~/.conda/envs/torch0.4/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    475             result = self._slow_forward(*input, **kwargs)
    476         else:
--> 477             result = self.forward(*input, **kwargs)
    478         for hook in self._forward_hooks.values():
    479             hook_result = hook(self, input, result)

<ipython-input-36-f9eeefa3c0b8> in forward(self, x)
    221         de2_ = self.d2(de1)
    222         #print de2_.size()
--> 223         de2 = torch.cat([en6add,de2_],1)
    224         #print de2.size()
    225 

RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 1. Got 5 and 4 in dimension 2 at /opt/conda/conda-bld/pytorch_1535491974311/work/aten/src/TH/generic/THTensorMath.cpp:3616

我认为问题是由于输入大小不是2的幂而引起的,但是我不确定如何针对给定的输入量(320,320)进行纠正。

1 个答案:

答案 0 :(得分:0)

此问题源于下采样(编码器)路径和上采样(解码器)路径中的变量大小不匹配。您的代码庞大且难以理解,但是通过插入print语句,我们可以检查

  1. en6add的大小为[1, 512, 5, 5]
  2. en7[1, 512, 2, 2]
  3. en8[1, 512, 1, 1]
  4. 然后向上采样采用2的幂:de1_[1, 512, 2, 2]
  5. de1 [1, 1024, 2, 2]
  6. de2_ [1, 512, 4, 4]

此时,您尝试将其与en6add进行连接,因此显然创建de2_的代码不够“充分采样”。我的强烈猜测是,您需要注意nn.ConvTranspose2doutput_padding参数,并可能在几个地方将其设置为1。我会尽力为您修复此错误,但该示例距离minimal太远,以至于我无法全神贯注。