sklearn.metrics.classification_report支持字段显示标签的错误编号

时间:2018-05-03 14:29:30

标签: python scikit-learn classification text-classification multiclass-classification

我正在使用sklearn.metrics.classification_report来评估我的分类结果。

y_pred = np.argmax(model.predict(X_test), axis=1)
y_true = np.argmax(y_test, axis=1)
print(classification_report(y_true, y_pred, target_names=list(le.classes_)))

这是我的结果:

              precision    recall  f1-score   support

 Technology       0.00      0.00      0.00         1
 Travel           0.00      0.00      0.00         5
 Fashion          0.00      0.00      0.00        25
 Entertainment    0.72      1.00      0.84       130
 Art              0.00      0.00      0.00         7
 Politic          0.00      0.00      0.00        12

 avg / total       0.52      0.72      0.61       180

问题实际上我有7个标签。订单包括技术,旅游,时尚,娱乐,艺术,政治,体育。实际上我的y_true结果中没有任何艺术品标签,但报告按顺序列出,所以它列出了Art但跳过了Sports。它写出了Politic for Art的结果,而Sports的结果则出现在Politic的行中。

为什么不跳过艺术?我不知道如何解决这个问题。

1 个答案:

答案 0 :(得分:1)

分类报告中的索引标签是参数“target_names”的值。请确保您为该参数提供了正确的值。

根据您的要求输出,您应该有 target_names = [“技术”,“旅行”,“时尚”,“娱乐”,“政治”,“体育”]

我建议,请检查“le.classes_”的输出,我不确定哪个变压器'le'指的是。