我正在训练CNN进行图像分类。相同的对象(然后带有相同的标签)在测试集中出现两次(如两个视点)。我想在上课时充分利用这一点。
现在,最后一层是Linear
层(PyTorch),我正在使用交叉熵作为损失函数。我想知道对每个对象进行最有把握的预测的最佳方法是什么。我应该首先计算LogSoftMax
并以最高的概率(在这两个预测数组中)选择该类,还是应该直接采用logit?
答案 0 :(得分:1)
由于LogSoftMax
保留顺序,因此最大logit将始终对应于最高置信度。因此,如果您感兴趣的只是找到最有信心的类别的索引,则无需执行操作。
获取最自信类别的索引的最简单方法可能是使用torch.argmax
。
例如
batch_size = 5
num_logits = 10
y = torch.randn(batch_size, num_logits)
preds = torch.argmax(y, dim=1)
在这种情况下会导致
>>> print(preds)
tensor([9, 7, 2, 4, 6])