pytorch如何获得每个图像属于一类的概率

时间:2019-11-08 12:21:56

标签: python classification pytorch

我是Pytorch的新手。我使用图像数据集训练和测试了线性分类器(nn.Linear),该图像数据集具有8个类别,其batch_size = 35。 在测试时,我想查看给定图像属于这8个类别中的任何一个的概率。这就是为什么我打印output.data变量的原因。但是这些数字大于1,并且不等于1。(我附有测试代码) 所以,我的问题是这些数字是什么意思?

谢谢!

correct = 0
total = 0
with torch.no_grad():
    for data in dataloaders['test']:
        images, labels = data[0].to(device), data[1].to(device)
        outputs = model(images)
        print(outputs.data)

        _, predicted = torch.max(outputs.data, 1)
        print(predicted)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 4000 test images: %d %%' % (
100 * correct / total))

1 个答案:

答案 0 :(得分:1)

您将logits作为神经网络的输出。

在输出上使用torch.nn.Softmax将值压缩到(0,1)范围内。

顺便说一句。您的网络应该输出对数,因为pytorch的损失(在这种情况下为torch.nn.CrossEntropyLoss)旨在在数值稳定的情况下与它们一起工作。