如何使用Mxnet Faster RCNN打印出每个类的平均精度以进行对象检测

时间:2017-11-30 08:43:36

标签: computer-vision deep-learning object-detection mxnet

我在我自己的数据集上使用更快的rcnn(mxnet)进行对象检测,该数据集有9个类(包括背景)。但是,我发现最终它只打出了训练过程中所有9个班级的平均准确度。此外,在测试过程中,它还只打印出所有9个类的平均精度和召回率。我想知道如何在培训过程中打印出每个班级的准确性,以及每个班级在考试过程中的召回和精确度? 或者有人可以告诉我,我应该在哪里看到我的目标? 一个理想的例子将在图像中显示。 enter image description here

1 个答案:

答案 0 :(得分:1)

您可以使用Scikit-learn函数sklearn.metrics.precision_recall_fscore_support¶。并且sklearn.metrics.classification_report用于美化版本。

在测试时,您将拥有一个真值数组(Y_true)和每个类(Y_prob)的预测概率数组。使用如下:

Y_pred = np.argmax(Y_prob, axis=1)
print(classification_report(Y_true, Y_pred))

         precision    recall  f1-score   support
class 0       0.50      1.00      0.67         1
class 1       0.00      0.00      0.00         1
class 2       1.00      0.67      0.80         3
avg / total   0.70      0.60      0.61         5

在培训时间每N批次需要更多的工作。如果您使用的是eval_metric方法,则可以设置回调参数和自定义module.fit;

model = mx.mod.Module(symbol=...)
model.fit(..., batch_end_callback = mx.callback.Speedometer(batch_size),
          eval_metric=custom_metric, ...)

您需要为扩展custom_metric的{​​{1}}创建一个新类,并实现一个打印(或甚至返回)每个类指标的mxnet.metric.EvalMetric方法。