使用pytorch中的函数计算二进制熵损失

时间:2019-11-14 01:35:22

标签: python deep-learning pytorch

我对计算二进制交叉熵有问题。我知道在pytorch中工作的方式是:

import torch
import torch.nn as nn
import torch.nn.functional as F
def lossfunc():
    return F.binary_cross_entropy

criterion = lossFunc()
input = torch.randn((3, 2), requires_grad=True)
target = torch.rand((3, 2), requires_grad=False)
loss = criterion(torch.sigmoid(input),target)

但是如何以这种方式完成lossfunc(),因为我不知道如何将参数传递给函数:

#the function that add sigmoid to input and calculate the binary cross entropy loss
def lossfunc():
   return

criterion = lossFunc()
input = torch.randn((3, 2), requires_grad=True)
target = torch.rand((3, 2), requires_grad=False)
loss = criterion(input,target)

1 个答案:

答案 0 :(得分:0)

我认为您正在将nn API与功能性F API混淆。在功能性api中,损失函数F.binary_cross_entropy可以直接用作函数。

nn api中,您需要创建一个损耗类的对象,例如criterion = nn.BCELoss()

因此,您只需执行以下操作即可:

def lossFunc(input, target):
   return F.binary_cross_entropy(torch.sigmoid(input),target)

input = torch.randn((3, 2), requires_grad=True)
target = torch.rand((3, 2), requires_grad=False)
loss = lossFunc(input,target)

PyTorch还提供了nn.nn.BCEWithLogitsLoss()F.binary_cross_entropy_with_logits(),它们结合了S型和二进制交叉熵。