我正在阅读Pytorch官方教程进行微调,我面临一个问题,那就是计算每个时期的损耗。
在此之前,我计算一批数据的损失,累积这些批损失,并找到这些值的平均值作为历时损失。但在该示例中,计算如下:
for inputs, labels in dataloaders[phase]:
inputs = inputs.to(device)
labels = labels.to(device)
# zero the parameter gradients
optimizer.zero_grad()
# forward
# track history if only in train
with torch.set_grad_enabled(phase == 'train'):
outputs = model(inputs)
_, preds = torch.max(outputs, 1)
loss = criterion(outputs, labels)
# backward + optimize only if in training phase
if phase == 'train':
loss.backward()
optimizer.step()
# statistics
running_loss += loss.item() * inputs.size(0)
running_corrects += torch.sum(preds == labels.data)
我的问题在running_loss += loss.item() * inputs.size(0)
行中。它是巴赫大小的批次乘以损失值。计算历元损失的真正方法是什么?
什么是损失单位?损失值的范围是多少?
答案 0 :(得分:1)
是的,代码段添加了批量大小与批量平均误差的乘积。如果要计算真实的总和。您可以使用
torch.nn.CrossEntropyLoss(reduction = "sum")
这将为您提供批次的错误总和。然后,您可以按如下所示直接为每个批次求和:
running_loss += loss.item()
损失值的范围取决于类的数量和特征向量。如果您使用reduction="sum"
,则问题中的代码将具有相同的running_loss,因为您的代码基本上可以做到
(loss/batch_size) * batch_size
与损失值相同。但是,反向传播会发生变化,因为一方面您是根据损失之和进行反向传播,另一方面您是根据平均损失来计算反向传播。