如何使用MxNet metrics api计算带有矢量标签的多类逻辑回归分类器的准确性? 以下是标签的示例:
arguments
使用此函数的天真方式会产生错误的结果,因为argmax会将模型输出压缩成具有最大概率值的索引
Class1: [1,0,0,0]
Class2: [0,1,0,0]
Class3: [0,0,1,0]
Class4: [0,0,0,1]
我目前的解决方案很少有问题:
def evaluate_accuracy(data_iterator, ctx, net):
acc = mx.metric.Accuracy()
for i, (data, label) in enumerate(data_iterator):
data = data.as_in_context(ctx)
label = label.as_in_context(ctx)
out = net(data)
p = nd.argmax(out, axis=1)
acc.update(preds=p, labels=label)
return acc.get()[1]
答案 0 :(得分:1)
Accuracy指标非常棘手。它并不适用于单热编码标签作为基本事实。
我发现这有点违反直觉,但你需要传递非单热编码标签作为基础事实,但实际的类(例如,2而不是[0,0,1,0])。否则,准确性将无法以您期望的方式发挥作用。请查看我之前的回复 - Why MXNet is reporting the incorrect validation accuracy?
此外,MxNet期望类以0开头。因此,如果您的类从1开始,那么您需要通过减去1来调整所有类。