PyTorch model validation error: Expected input batch_size (3) to match target batch_size (4)

Date: 2018-09-05 06:57:50

Tags: python neural-network pytorch

I am building a neural network in PyTorch that should classify 102 classes.

I have the following validation function:

def validation(model, testloader, criterion):
    test_loss = 0
    accuracy = 0

    for inputs, classes in testloader:
        inputs = inputs.to('cuda')
        output = model.forward(inputs)
        test_loss += criterion(output, labels).item()

        ps = torch.exp(output)
        equality = (labels.data == ps.max(dim=1)[1])
        accuracy += equality.type(torch.FloatTensor).mean()

    return test_loss, accuracy

Training code (which calls validation):

epochs = 3
print_every = 40
steps = 0
running_loss = 0
testloader = dataloaders['test']

# change to cuda
model.to('cuda')

for e in range(epochs):
    running_loss = 0
    for ii, (inputs, labels) in enumerate(dataloaders['train']):
        steps += 1

        inputs, labels = inputs.to('cuda'), labels.to('cuda')

        optimizer.zero_grad()

        # Forward and backward passes
        outputs = model.forward(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        if steps % print_every == 0:
            model.eval()
            with torch.no_grad():
                test_loss, accuracy = validation(model, testloader, criterion)

            print("Epoch: {}/{}.. ".format(e+1, epochs),
                  "Training Loss: {:.3f}.. ".format(running_loss/print_every),
                  "Test Loss: {:.3f}.. ".format(test_loss/len(testloader)),
                  "Test Accuracy: {:.3f}".format(accuracy/len(testloader)))

            running_loss = 0
            model.train()

I get this error message:

ValueError: Expected input batch_size (3) to match target batch_size (4).

Full traceback:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-63-f9f67ed13b94> in <module>()
     28             model.eval()
     29             with torch.no_grad():
---> 30                 test_loss, accuracy = validation(model, testloader, criterion)
     31 
     32             print("Epoch: {}/{}.. ".format(e+1, epochs),

<ipython-input-62-dbc77acbda5e> in validation(model, testloader, criterion)
      6         inputs = inputs.to('cuda')
      7         output = model.forward(inputs)
----> 8         test_loss += criterion(output, labels).item()
      9 
     10         ps = torch.exp(output)

/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    489             result = self._slow_forward(*input, **kwargs)
    490         else:
--> 491             result = self.forward(*input, **kwargs)
    492         for hook in self._forward_hooks.values():
    493             hook_result = hook(self, input, result)

/opt/conda/lib/python3.6/site-packages/torch/nn/modules/loss.py in forward(self, input, target)
    191         _assert_no_grad(target)
    192         return F.nll_loss(input, target, self.weight, self.size_average,
--> 193                           self.ignore_index, self.reduce)
    194 
    195 

/opt/conda/lib/python3.6/site-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce)
   1328     if input.size(0) != target.size(0):
   1329         raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
-> 1330                          .format(input.size(0), target.size(0)))
   1331     if dim == 2:
   1332         return torch._C._nn.nll_loss(input, target, weight, size_average, ignore_index, reduce)

ValueError: Expected input batch_size (3) to match target batch_size (4).
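The message is raised by the batch-size check inside `F.nll_loss`, which compares `input.size(0)` with `target.size(0)`. A minimal sketch that reproduces it, assuming a recent PyTorch install; the shapes (3 rows of log-probabilities over 102 classes, 4 targets) are chosen only to mirror the error, and `batch_size_mismatch_message` is a hypothetical helper:

```python
import torch
import torch.nn.functional as F


def batch_size_mismatch_message(n_inputs=3, n_targets=4, n_classes=102):
    """Hypothetical helper: call F.nll_loss with mismatched batch sizes and
    return the error text, or None if no error was raised."""
    # Log-probabilities for n_inputs samples (what a LogSoftmax head produces)
    log_probs = F.log_softmax(torch.randn(n_inputs, n_classes), dim=1)
    # Targets for a different number of samples
    targets = torch.randint(0, n_classes, (n_targets,))
    try:
        F.nll_loss(log_probs, targets)
    except (ValueError, RuntimeError) as err:
        # The exception class has varied across PyTorch versions, so both are caught
        return str(err)
    return None


print(batch_size_mismatch_message())
```

With matching sizes (for example `batch_size_mismatch_message(4, 4)`) the call succeeds and the helper returns `None`, so the error depends only on the first dimensions of the two tensors disagreeing.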

I don't know where the error comes from. Without the validation code, the training part works perfectly.

1 answer:

Answer 0 (score: 3)

In your validation function,

def validation(model, testloader, criterion):
    test_loss = 0
    accuracy = 0

    for inputs, classes in testloader:
        inputs = inputs.to('cuda')
        output = model.forward(inputs)
        test_loss += criterion(output, labels).item()

        ps = torch.exp(output)
        equality = (labels.data == ps.max(dim=1)[1])
        accuracy += equality.type(torch.FloatTensor).mean()

    return test_loss, accuracy

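you iterate over the test loader and unpack each batch into the variables inputs, classes, but you pass labels to the criterion. Since labels is never assigned inside validation, Python looks it up in the global scope, where it still holds the targets of the most recent training batch from your training loop (4 samples here), while output comes from the current test batch (3 samples). That is also why training alone works: there, labels is the current batch. Unpack the batch as inputs, labels instead, and move labels to the GPU too, because it is passed to the criterion together with output and compared with ps.max(dim=1)[1], both of which are on the GPU.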
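A corrected version of the function, as a sketch. It assumes, as the traceback and the torch.exp call suggest, that the model ends in LogSoftmax and the criterion is NLLLoss. The device parameter is an addition not in the original (which hard-codes 'cuda') so the function can also run on CPU:

```python
import torch


def validation(model, testloader, criterion, device="cuda"):
    """Return (summed test loss, summed per-batch accuracy) over testloader.

    `device` is an assumed extra parameter; the original hard-codes 'cuda'.
    """
    test_loss = 0.0
    accuracy = 0.0

    # Unpack into `labels` (not `classes`) so the loss uses THIS batch's
    # targets instead of a stale global `labels` from the training loop
    for inputs, labels in testloader:
        # Inputs and targets must be on the same device as the model output
        inputs, labels = inputs.to(device), labels.to(device)
        output = model(inputs)  # calling the module runs forward() plus hooks
        test_loss += criterion(output, labels).item()

        # Model output is assumed to be log-probabilities (LogSoftmax + NLLLoss)
        ps = torch.exp(output)
        equality = labels == ps.max(dim=1)[1]
        accuracy += equality.float().mean().item()

    return test_loss, accuracy
```

The training loop can call it unchanged, inside model.eval() and torch.no_grad(), and keep dividing both values by len(testloader). Accuracy is returned as a Python float rather than a tensor, which the {:.3f} formatting in the print call handles the same way.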