我对计算二进制交叉熵有问题。我知道在pytorch中工作的方式是:
import torch
import torch.nn as nn
import torch.nn.functional as F
def lossfunc():
return F.binary_cross_entropy
criterion = lossFunc()
input = torch.randn((3, 2), requires_grad=True)
target = torch.rand((3, 2), requires_grad=False)
loss = criterion(torch.sigmoid(input),target)
但是如何以这种方式完成lossfunc(),因为我不知道如何将参数传递给函数:
#the function that add sigmoid to input and calculate the binary cross entropy loss
def lossfunc():
return
criterion = lossFunc()
input = torch.randn((3, 2), requires_grad=True)
target = torch.rand((3, 2), requires_grad=False)
loss = criterion(input,target)
答案 0 :(得分:0)
我认为您正在将nn
API与功能性F
API混淆。在功能性api中,损失函数F.binary_cross_entropy
可以直接用作函数。
在nn
api中,您需要创建一个损耗类的对象,例如criterion = nn.BCELoss()
因此,您只需执行以下操作即可:
def lossFunc(input, target):
return F.binary_cross_entropy(torch.sigmoid(input),target)
input = torch.randn((3, 2), requires_grad=True)
target = torch.rand((3, 2), requires_grad=False)
loss = lossFunc(input,target)
PyTorch还提供了nn.nn.BCEWithLogitsLoss()
和F.binary_cross_entropy_with_logits()
,它们结合了S型和二进制交叉熵。