How do I fix the size mismatch in this model?

Time: 2020-04-20 17:40:51

Tags: pytorch

from torchvision import transforms

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(256),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

batchsize = 4

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=5, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=5, padding=1)
        self.fc1 = nn.Linear(256*6*6, 4096)  # expects 9216 input features
        self.fc2 = nn.Linear(4096, 4096)
        self.fc3 = nn.Linear(4096, 2)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.reshape(x.size(0), -1)  # flatten to (batch, features)
        print(x.shape)  # debug: shape of the flattened batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

RuntimeError: size mismatch, m1: [4 x 373248], m2: [9216 x 4096] at C:/w/1/s/tmp_conda_3.8_075429/conda/conda-bld/pytorch_1579852542185/work/aten/src\THC/generic/THCTensorMathBlas.cu:290

1 Answer:

Answer 0 (score: 0)

From the debugging information provided, all we can say is that one of the layers has a size mismatch.

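Looking at the error, the mismatch appears to be in the first linear layer: m1 is the flattened batch of shape [4 x 373248], while m2 is the weight of fc1, which expects 9216 (= 256 * 6 * 6) inputs. You should change the input size of fc1 to 373248 instead of 256 * 6 * 6.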
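As a rough check of where 373248 comes from, the spatial size can be traced by hand with the standard output-size formula for convolution and pooling layers. The sketch below is plain Python; the helper names `conv_out` and `flattened_features` are made up for illustration, and it assumes square inputs and the layer settings shown in the question.

```python
def conv_out(size, kernel, padding=0, stride=1):
    """Spatial output size of a convolution or pooling layer (floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


def flattened_features(image_size):
    """Number of features reaching fc1 for a square input of image_size pixels."""
    s = conv_out(image_size, kernel=5, padding=1)  # conv1
    s = conv_out(s, kernel=2, stride=2)            # pool
    s = conv_out(s, kernel=5, padding=1)           # conv2
    s = conv_out(s, kernel=2, stride=2)            # pool
    return 128 * s * s                             # conv2 has 128 output channels


print(flattened_features(224))  # 224x224 crops from the 'train' transform
print(flattened_features(256))  # 256x256 crops from the 'val' transform
```

For 224x224 inputs this gives 128 * 54 * 54 = 373248, matching m1 in the error. Note that the 'val' transform ends with CenterCrop(256), so validation images are 256x256 and would flatten to a different size (128 * 62 * 62 = 492032); both transforms probably need to produce the same image size for a fixed-size fc1 to work.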

This should also be clear from the output of the print statement print(x.shape).
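One way to see that shape without training is to push a dummy batch through the convolutional part of the network only. This is a minimal sketch, assuming the conv and pooling layers from the question; the helper `flat_shape` is hypothetical and not part of the original code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Only the convolutional part of Net, copied from the question.
conv1 = nn.Conv2d(3, 64, kernel_size=5, padding=1)
pool = nn.MaxPool2d(2, 2)
conv2 = nn.Conv2d(64, 128, kernel_size=5, padding=1)


def flat_shape(image_size, batch_size=4):
    """Run a dummy batch through the conv layers and return the flattened shape."""
    x = torch.zeros(batch_size, 3, image_size, image_size)
    with torch.no_grad():
        x = pool(F.relu(conv1(x)))
        x = pool(F.relu(conv2(x)))
    return tuple(x.reshape(x.size(0), -1).shape)


print(flat_shape(224))  # second entry is the in_features value fc1 needs
```

The second entry of the returned tuple is what nn.Linear needs as its first argument. Keep in mind that a linear layer mapping 373248 inputs to 4096 outputs holds roughly 1.5 billion weights, so it may be more practical to shrink the feature map first, for example with additional pooling or nn.AdaptiveAvgPool2d, and size fc1 to match.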