Pytorch - 如何使用BCE损失

时间:2017-04-30 16:33:08

标签: torch autoencoder loss pytorch

我想在pytorch中编写一个简单的自动编码器并使用BCELoss,但是我得到了NaN,因为它期望目标在0和1之间。有人会发布一个简单的BCELoss用例吗?

2 个答案:

答案 0 :(得分:15)

  

BCELoss功能在数值上并不稳定。看到这个问题:   https://github.com/pytorch/pytorch/issues/751

编辑: 此问题已通过Pull #1792解决,因此BCELoss现在数字稳定。

如果你从源代码构建pytorch,你可以使用数值稳定的函数BCEWithLogitsLoss(在https://github.com/pytorch/pytorch/pull/1792中提供),它将logit作为输入。

否则,您可以使用以下功能(在上述问题中由yzgao提供):

class StableBCELoss(nn.modules.Module):
       def __init__(self):
             super(StableBCELoss, self).__init__()
       def forward(self, input, target):
             neg_abs = - input.abs()
             loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()
             return loss.mean()

答案 1 :(得分:3)

您可能希望在网络末端使用sigmoid层。以这种方式,数字代表概率。还要确保目标是二进制数。如果您发布完整的代码,我们可能会提供更多帮助。