用Pytorch进行多类分类

时间:2020-03-30 20:02:03

标签: pytorch loss-function softmax cross-entropy

我是Pytorch的新手,我需要澄清多类分类。

我正在微调DenseNet神经网络,因此它可以识别3个不同的类。

因为这是一个多类问题,所以我必须以这种方式替换分类层:

kernelCount = self.densenet121.classifier.in_features
self.densenet121.classifier = nn.Sequential(nn.Linear(kernelCount, 3), nn.Softmax(dim=1))

并使用CrossEntropyLoss作为损失函数:

loss = torch.nn.CrossEntropyLoss(reduction='mean')

通过在Pytorch论坛上阅读,我发现CrossEntropyLoss将softmax函数应用于神经网络的输出。这是真的?我应该从网络结构中删除Softmax激活功能吗?

那测试阶段呢?如果包括在内,我必须在模型的输出上调用softmax函数吗?

预先感谢您的帮助。

1 个答案:

答案 0 :(得分:1)

是的,CrossEntropyLoss隐含地应用softmax。您应该在网络的末端删除softmax层,因为softmax不是幂等的,因此两次应用它会造成语义错误。

就评估/测试而言。请记住,softmax是单调递增的操作(意味着应用时输出的相对顺序不会改变)。因此,在softmax之前和之后的argmax结果将得到相同的结果。

唯一可能要在评估期间显式执行softmax的情况是,由于某种原因需要实际的置信度值。如果需要,您可以在评估期间使用torch.softmax在网络输出上显式应用softmax。