PyTorch中的运行损耗是什么以及如何计算

时间:2020-04-08 02:43:04

标签: python deep-learning pytorch torch torchvision

我查看了PyTorch文档中的this教程,以了解迁移学习。我只有一条线听不懂。

使用loss = criterion(outputs, labels)计算损耗后,使用running_loss += loss.item() * inputs.size(0)计算运行损耗,最后,使用running_loss / dataset_sizes[phase]计算历时损耗。

不是loss.item()是整个小批量产品(如果我错了,请纠正我)。也就是说,如果batch_size为4,则loss.item()会损失整个4张图像。如果是这样,为什么在计算loss.item()时将inputs.size(0)running_loss相乘?在这种情况下,这一步骤难道不是多余的乘法吗?

任何帮助将不胜感激。谢谢!

2 个答案:

答案 0 :(得分:6)

这是因为CrossEntropy或其他损失函数给定的损失除以元素数,即默认情况下减少参数为mean

torch.nn.CrossEntropyLoss(weight = None,size_average = None,ignore_index = -100,reduce = None,reduction ='mean')

因此,loss.item()包含整个微型批次的损失,但除以批次大小。这就是在计算loss.item()时将inputs.size(0)乘以running_loss给定的批量大小的原因。

答案 1 :(得分:2)

如果batch_size为4,则loss.item()将给出整个4张图像的损失

这取决于loss的计算方式。请记住,loss是一个张量,就像其他所有张量一样。通常,PyTorch API默认会返回avg loss

“对于每个小批量,损失是通过观察得出的平均值。”

张量t.item()

t仅将其转换为python的默认float32。

更重要的是,如果您是PyTorch的新手,那么知道我们使用t.item()而不是t来保持运行损耗可能会有所帮助,因为PyTorch张量存储了其值的历史记录,很快会使您的GPU过载。