我在我自己的数据集上使用更快的rcnn(mxnet)进行对象检测,该数据集有9个类(包括背景)。但是,我发现最终它只打出了训练过程中所有9个班级的平均准确度。此外,在测试过程中,它还只打印出所有9个类的平均精度和召回率。我想知道如何在培训过程中打印出每个班级的准确性,以及每个班级在考试过程中的召回和精确度? 或者有人可以告诉我,我应该在哪里看到我的目标? 一个理想的例子将在图像中显示。 enter image description here
答案 0 :(得分:1)
您可以使用Scikit-learn函数sklearn.metrics.precision_recall_fscore_support¶
。并且sklearn.metrics.classification_report
用于美化版本。
在测试时,您将拥有一个真值数组(Y_true
)和每个类(Y_prob
)的预测概率数组。使用如下:
Y_pred = np.argmax(Y_prob, axis=1)
print(classification_report(Y_true, Y_pred))
precision recall f1-score support
class 0 0.50 1.00 0.67 1
class 1 0.00 0.00 0.00 1
class 2 1.00 0.67 0.80 3
avg / total 0.70 0.60 0.61 5
在培训时间每N批次需要更多的工作。如果您使用的是eval_metric
方法,则可以设置回调参数和自定义module.fit
;
model = mx.mod.Module(symbol=...)
model.fit(..., batch_end_callback = mx.callback.Speedometer(batch_size),
eval_metric=custom_metric, ...)
您需要为扩展custom_metric
的{{1}}创建一个新类,并实现一个打印(或甚至返回)每个类指标的mxnet.metric.EvalMetric
方法。