标签是列表时的PyTorch损失函数?

时间:2019-03-14 07:53:46

标签: neural-network deep-learning pytorch

[帮助朋友发帖] 我有一个模型,该模型返回长度为k的预测的二进制序列,例如[0, 0.2, 0.6, 0.4, 0.8],并且我有类似[0, 1, 1, 0, 0]的标签。我如何在这里定义损失函数?

1 个答案:

答案 0 :(得分:0)

如果它是二进制分类,并且预测张量来自S形函数,则可以使用torch.nn.BCELoss二进制交叉熵损失。如果您没有将Sigmoid / softmax应用于预测张量,则最好使用torch.nn.BCEWithLogitsLoss

在PyTorch中,损失函数称为标准,您可以将它们定义为:

criterion = nn.BCELoss()
loss = criterion(prediction, target)