如何有效地计算PyTorch中的混淆矩阵?

时间:2018-05-30 09:29:56

标签: pytorch

我有一个包含我的预测的张量和一个包含我的二进制分类问题的实际标签的张量。如何有效地计算混淆矩阵?

1 个答案:

答案 0 :(得分:0)

在使用for-loop的第一个版本证明效率低下之后,这是我迄今为止提出的最快的解决方案,对于两个等维张量predictiontruth

def confusion(prediction, truth):
    confusion_vector = prediction / truth

    true_positives = torch.sum(confusion_vector == 1).item()
    false_positives = torch.sum(confusion_vector == float('inf')).item()
    true_negatives = torch.sum(torch.isnan(confusion_vector)).item()
    false_negatives = torch.sum(confusion_vector == 0).item()

    return true_positives, false_positives, true_negatives, false_negatives

https://gist.github.com/the-bass/cae9f3976866776dea17a5049013258d

上的评论版和测试用例