如何解决错误:PyTorch中预期的输入批次大小与目标批次大小不匹配?

时间:2020-06-02 17:31:35

标签: python machine-learning deep-learning pytorch

我试图通过PyTorch在CIFAR10数据集上创建逻辑模型。但是我遇到一个错误:

ValueError:预期输入的batch_size(900)匹配目标batch_size(300)。

我认为正在发生的事是3 * 100是300。所以RGB图像的3轴可能正在这样做,但是我不知道该怎么解决。

这些是我的超参数。

batch_size = 100
learning_rate = 0.001

# Other constants
input_size = 32*32
num_classes = 10

在这里,我将数据分为训练,验证和测试数据。


transform_train = transforms.Compose([transforms.Resize((32,32)),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.RandomRotation(10),
                                      transforms.RandomAffine(0, shear=10, scale=(0.8,1.2)),
                                      transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                               ])


transform = transforms.Compose([transforms.Resize((32,32)),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                               ])
training_dataset = CIFAR10(root='D:\PyTorch\cifar-10-python', train=True, download=True, transform=transform_train)
train_ds, val_ds = random_split(training_dataset, [40000, 10000])
test_ds = CIFAR10(root='D:\PyTorch\cifar-10-python', train=False, download=True, transform=transform)

train_loader = DataLoader(train_ds, batch_size=100, shuffle=True)
val_loader = DataLoader(val_ds, batch_size = 100, shuffle = False)
test_loader = DataLoader(test_ds, batch_size = 100, shuffle=False)

这是模型。

class CifarModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(input_size,  num_classes)
    def forward(self, xb):
        xb = xb.view(-1, 32*32)
        #xb = xb.reshape(-1, 784)
        print(xb.shape)
        out = self.linear(xb)
        return out

    def training_step(self, batch):
        images, labels = batch 
        out = self(images)                  # Generate predictions
        loss = F.cross_entropy(out, labels) # Calculate loss
        return loss

    def validation_step(self, batch):
        images, labels = batch 
        out = self(images)                    # Generate predictions
        loss = F.cross_entropy(out, labels)   # Calculate loss
        acc = accuracy(out, labels)           # Calculate accuracy
        return {'val_loss': loss.detach(), 'val_acc': acc.detach()}

    def validation_epoch_end(self, outputs):
        batch_losses = [x['val_loss'] for x in outputs]
        epoch_loss = torch.stack(batch_losses).mean()   # Combine losses
        batch_accs = [x['val_acc'] for x in outputs]
        epoch_acc = torch.stack(batch_accs).mean()      # Combine accuracies
        return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}

    def epoch_end(self, epoch, result):
        print("Epoch [{}], val_loss: {:.4f}, val_acc: {:.4f}".format(epoch, result['val_loss'], result['val_acc']))

model = CifarModel()
def accuracy(outputs, labels):
    _, preds = torch.max(outputs, dim=1)
    return torch.tensor(torch.sum(preds == labels).item() / len(preds))
def evaluate(model, val_loader):
    outputs = [model.validation_step(batch) for batch in val_loader]
    return model.validation_epoch_end(outputs)

def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD):
    history = []
    optimizer = opt_func(model.parameters(), lr)
    for epoch in range(epochs):
        # Training Phase 
        for batch in train_loader:
            loss = model.training_step(batch)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        # Validation phase
        result = evaluate(model, val_loader)
        model.epoch_end(epoch, result)
        history.append(result)
    return history
evaluate(model, val_loader)

这是我运行评估函数时遇到的错误:

torch.Size([900, 1024])

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-23-3621eab8de1a> in <module>
     21         history.append(result)
     22     return history
---> 23 evaluate(model, val_loader)

<ipython-input-23-3621eab8de1a> in evaluate(model, val_loader)
      3     return torch.tensor(torch.sum(preds == labels).item() / len(preds))
      4 def evaluate(model, val_loader):
----> 5     outputs = [model.validation_step(batch) for batch in val_loader]
      6     return model.validation_epoch_end(outputs)
      7 

<ipython-input-23-3621eab8de1a> in <listcomp>(.0)
      3     return torch.tensor(torch.sum(preds == labels).item() / len(preds))
      4 def evaluate(model, val_loader):
----> 5     outputs = [model.validation_step(batch) for batch in val_loader]
      6     return model.validation_epoch_end(outputs)
      7 

<ipython-input-22-c9e17d21eaff> in validation_step(self, batch)
     19         images, labels = batch
     20         out = self(images)                    # Generate predictions
---> 21         loss = F.cross_entropy(out, labels)   # Calculate loss
     22         acc = accuracy(out, labels)           # Calculate accuracy
     23         return {'val_loss': loss.detach(), 'val_acc': acc.detach()}

~\Anaconda3\lib\site-packages\torch\nn\functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
   2019     if size_average is not None or reduce is not None:
   2020         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2021     return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
   2022 
   2023 

~\Anaconda3\lib\site-packages\torch\nn\functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
   1834     if input.size(0) != target.size(0):
   1835         raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
-> 1836                          .format(input.size(0), target.size(0)))
   1837     if dim == 2:
   1838         ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)

ValueError: Expected input batch_size (900) to match target batch_size (300).

1 个答案:

答案 0 :(得分:1)

我看到的一个问题是以下行:

xb = xb.view(-1, 32*32) 

在这里您说的是输入图像只有一个通道。换句话说,灰度。对其进行更改以反映通道数(RGB):

xb = xb.view(-1, 32*32*3)