有没有人知道如何计算语义分割中的前k个准确性?在分类中,我们可以将topk准确性计算为:
correct = output.eq(gt.view(1, -1).expand_as(output))
非常感谢!
答案 0 :(得分:1)
您正在寻找torch.topk
函数,该函数可计算一个维度上的前k个值。
torch.topk
的第二个输出是“ arg top k”:最大值的k个索引。
答案 1 :(得分:0)
假设您的输出是一系列按照您的课程列表排序的分数labels
:
import torch
scores, indices = torch.topk(output, k)
correct = labels[indices]