CNN:在许多人中做出最自信的预测

时间:2019-08-09 04:01:07

标签: conv-neural-network pytorch

我正在训练CNN进行图像分类。相同的对象(然后带有相同的标签)在测试集中出现两次(如两个视点)。我想在上课时充分利用这一点。

现在,最后一层是Linear层(PyTorch),我正在使用交叉熵作为损失函数。我想知道对每个对象进行最有把握的预测的最佳方法是什么。我应该首先计算LogSoftMax并以最高的概率(在这两个预测数组中)选择该类,还是应该直接采用logit?

1 个答案:

答案 0 :(得分:1)

由于LogSoftMax保留顺序,因此最大logit将始终对应于最高置信度。因此,如果您感兴趣的只是找到最有信心的类别的索引,则无需执行操作。

获取最自信类别的索引的最简单方法可能是使用torch.argmax

例如

batch_size = 5
num_logits = 10
y = torch.randn(batch_size, num_logits)
preds = torch.argmax(y, dim=1)

在这种情况下会导致

>>> print(preds)
tensor([9, 7, 2, 4, 6])