F.relu(self.fc1(x))导致RuntimeError问题

时间:2018-11-19 00:59:40

标签: python pytorch

我已经为我的训练和验证数据集实现了以下CNN 分别包含90和20个图像,分为3类:

def __init__(self):
    super(Net, self).__init__()
    self.conv1 = nn.Conv2d(3, 6, 5)
    self.pool = nn.MaxPool2d(2, 2)
    self.conv2 = nn.Conv2d(6, 16, 5)
    self.fc1 = nn.Linear(16 * 5 * 5, 120)
    self.fc2 = nn.Linear(120, 50)
    self.fc3 = nn.Linear(50, len(classes))


def forward(self, x):
    print(x.shape)
    x = self.pool(F.relu(self.conv1(x)))
    print(x.shape)
    x = self.pool(F.relu(self.conv2(x)))
    print(x.shape)
    x = x.view(x.size(0),-1)
        #x = x.view(-1,x.size(1)*x.size(2)*x.size(3))
        #x = x.view(-1, 16 * 5 * 5)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x

运行时出现以下错误:

RuntimeError: size mismatch, m1: [1 x 214720], m2: [400 x 120] at /opt/conda/conda-bld/pytorch_1524584710464/work/aten/src/TH/generic/THTensorMath.c:2033 for x = F.relu(self.fc1(x))

有人可以建议我该怎么做才能摆脱这个问题吗?我通过遵循一些线程更改了x.view(...)。但是,在这种情况下没有帮助。

1 个答案:

答案 0 :(得分:1)

in_channelsself.fc1的大小取决于输入图像的大小,而不取决于内核的大小。

就您而言, self.fc1 = nn.Linear(16 * 5 * 5, 120)应该是nn.Linear(16 * image_size * image_size) 其中image_size:是最后一个卷积层中图像的大小。

示例代码:

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class Net(nn.Module):
    def __init__(self, classes):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5, padding=2)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5, padding=2)
        self.fc1 = nn.Linear(16 * 25 * 25, 120)
        self.fc2 = nn.Linear(120, 50)
        self.fc3 = nn.Linear(50, classes)

    def forward(self, x):
        print('one', x.shape)
        x = self.pool(F.relu(self.conv1(x)))
        print('two', x.shape)
        x = self.pool(F.relu(self.conv2(x)))
        print('three', x.shape)
        x = x.view(-1, np.product(x.shape[1:]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

x = torch.rand((32, 3, 100, 100))
net = Net(2)
out= net(x)
print('out', out.shape)

one torch.Size([32, 3, 100, 100])
two torch.Size([32, 6, 50, 50])
three torch.Size([32, 16, 25, 25])
out torch.Size([32, 2])