如何计算3D图像火炬的交叉熵?

时间:2018-12-13 10:38:59

标签: python pytorch

See the figure here

左边是(2,480,640),它是softmax值

正确的是(2,480,640),它是一键编码值

如何获得所有元素的交叉熵损失?

1 个答案:

答案 0 :(得分:0)

与其他任何图像完全相同。使用binary_cross_entropy(left, right)。请注意

  1. 两者都必须为torch.float32 dtype,因此您可能需要先使用right来转换right.to(torch.float32)
  2. 如果您的left张量包含logit而不是概率,则调用binary_cross_entropy_with_logits(left, right)胜过调用binary_cross_entropy(torch.sigmoid(left), right)