Pytorch中的许多损失功能都在nn.modules.loss和nn.functional中实现。
例如,下面的两行返回相同的结果。
import torch.nn as nn
import torch.functional as F
nn.L1Loss()(x,y)
F.l1_loss(x,y)
为什么有两个实现?
答案 0 :(得分:2)
我认为这是部分应用程序的情况-能够将许多配置变量与损失函数对象“捆绑”在一起很有用。在大多数情况下,损失函数必须以prediction
和ground_truth
作为参数。这使得损失函数具有相当统一的基本API。但是,它们在细节上有所不同。例如,并非每个损失函数都有一个reduction
参数。 BCEWithLogitsLoss
具有weight
和pos_weight
参数; PoissonNLLLoss
具有log_input
,eps
。编写类似
def one_epoch(model, dataset, loss_fn, optimizer):
for x, y in dataset:
model.zero_grad()
y_pred = model(x)
loss = loss_fn(y_pred, y)
loss.backward()
optimizer.step()
可以与实例化的BCEWithLogitsLoss
和PoissonNLLLoss
一起使用。但是,由于需要进行簿记,因此它无法与职能部门合作。相反,您必须先创建
loss_fn_packed = functools.partial(F.binary_cross_entropy_with_logits, weight=my_weight, reduction='sum')
,然后您才能将其与上面定义的one_epoch
一起使用。但是,这种包装已经随面向对象的损失API一起提供了,还带有一些麻烦(由于损失子类nn.Module
,因此您可以使用向前和向后钩子,在cpu和gpu之间移动内容,等等)。>
答案 1 :(得分:1)
有没有doc的BCEWithLogistsLoss的代码:
class BCEWithLogitsLoss(_Loss):
def __init__(self, weight: Optional[Tensor] = None, size_average=None, reduce=None, reduction: str = 'mean',
pos_weight: Optional[Tensor] = None) -> None:
super(BCEWithLogitsLoss, self).__init__(size_average, reduce, reduction)
self.register_buffer('weight', weight)
self.register_buffer('pos_weight', pos_weight)
def forward(self, input: Tensor, target: Tensor) -> Tensor:
return F.binary_cross_entropy_with_logits(input, target,
self.weight,
pos_weight=self.pos_weight,
reduction=self.reduction)
如果不考虑参数传递,类和函数的实现完全一样。但是,使用类实现可以使您的代码更加简洁和可读,例如
使用函数
loss_func=binary_cross_entropy_with_logits
def train(model, dataloader, loss_fn, optimizer, weight, size_average, reduce, reduction, pos_weight):
for x, y in dataloader:
model.zero_grad()
y_pred = model(x)
loss = loss_fn(y_pred, y, weight, size_average, reduce, reduction, pos_weight)
loss.backward()
optimizer.step()
使用类
loss_func = BCEWithLogitsLoss(weight, size_average, reduce, reduction, pos_weight)
def train(model, dataloader, loss_fn, optimizer):
for x, y in dataloader:
model.zero_grad()
y_pred = model(x)
loss = loss_fn(y_pred, y)
loss.backward()
optimizer.step()
如果你有多个参数或不同的损失函数,类的实现会更好。