使用Pytorch使用CNN进行线性回归:输入形状与目标形状不匹配:输入[400 x 1],目标[200 x 1]

时间:2018-09-14 03:47:56

标签: python linear-regression pytorch

让我先解释一下目标。假设我有1000张图片,每张图片都有相关的质量得分[范围为0-10]。现在,我正在尝试使用CNN进行回归(在PyTorch中)进行图像质量评估。我已将图像分为相等大小的补丁。现在,我创建了一个CNN网络以执行线性回归。

以下是代码:

class MultiLabelNN(nn.Module):
    def __init__(self):
        super(MultiLabelNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 5)
        self.fc1 = nn.Linear(3200,1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 1)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = x.view(-1, 3200)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
        return x  

运行此网络代码时,出现以下错误

输入形状和目标形状不匹配:输入[400 x 1],目标[200 x 1]

目标形状为[200x1]是因为我已采用200的批处理大小。我发现了以下解决方案:如果更改“ self.fc1 = nn.Linear(3200,1024)”和“ x = x.view (-1,3200)”,从3200到6400,我的代码运行没有任何错误。

类似地,如果输入12800而不是6400,它将输入错误并且目标形状不匹配:输入[100 x 1],目标[200 x 1]

现在,我的疑问是我无法理解其背后的原因。如果我将200张图像作为网络输入,那么当我从卷积层移到完全连接层时,为什么更改参数时输入形状会受到影响。我希望我已经明确提到了我的疑问。即使我有人有任何疑问,也请问我。这将是一个很大的帮助。预先感谢。

1 个答案:

答案 0 :(得分:1)

class MultiLabelNN(nn.Module):
    def __init__(self):
        super(MultiLabelNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 5)
        self.fc1 = nn.Linear(6400,1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 1)

   def forward(self, x):
       #shape of x is (b_s, 32,32,1)
       x = self.conv1(x) #shape of x is (b_s, 28,28,132)
       x = F.relu(x)
       x = self.pool(x) #shape of x now becomes (b_s X 14 x 14 x 32)
       x = self.conv2(x) # shape(b_s, 10x10x64)
       x = F.relu(x)#size is (b_s x 10 x 10 x 64)
       x = x.view(-1, 3200) # shape of x is now(b_s*2, 3200)
       #this is the problem 
       #you can fc1 to be of shape (6400,1024) and that will work 
       x = self.fc1(x)
       x = F.relu(x)
       x = self.fc2(x)
       x = F.relu(x)
       x = self.fc3(x)
       return x  

我认为这应该有效。让我知道是否仍然存在一些错误。