获得多类分类中最差的预测类

时间:2019-08-19 08:45:25

标签: python scikit-learn confusion-matrix

我正在研究一个多类分类问题,在那里我有很多不同的类(50多个)。

问题是,我想突出显示最差的预测类别(例如,在混淆矩阵等中),以便对我的分类器进行进一步的调整。

我的预测和测试数据保存在一个列表中(来自sklearn的一个小示例):

y_true = [2, 0, 2, 2, 0, 1]
y_pred = [0, 0, 2, 2, 0, 2]
confusion_matrix(y_true, y_pred)
array([[2, 0, 0],
       [0, 0, 1],
       [1, 0, 2]])

在该示例中,如何从矩阵中获得1类?那里的预测是完全错误的。有没有一种方法可以根据他们的真实正面预测对他们进行分类?

2 个答案:

答案 0 :(得分:1)

您可以使用scikit-learn中的classifiction_report,它将返回具有精度,召回率和F分数的字典。然后,您可以按排序方式打印字典,以便轻松查看最差的预测类别。

#prints classification_report     
print(classification_report(y_true, y_pred)

#returns a dict, which you can easily sort by prediction
report = classification_report(y_true, y_pred, output_dict=True)

答案 1 :(得分:1)

您可以为此使用简单的功能:

def print_class_accuracies(confusion_matrix):
    # get the number of occurrences for each class
    counts = {cl: y_true.count(cl) for cl in set(y_true)}
    # extract the diagonal values (true positives)
    tps = dict(enumerate(conf.diagonal()))
    # Get the accuracy for each class, preventing ZeroDivisionErrors
    pred_accuracy = {cl: tps[cl]/counts.get(cl, 1) for cl in tps}
    # Get a ranking, worst accuracies are first/lowest
    ranking = sorted([(acc,cl) for cl, acc in pred_accuracy.items()])
    # Pretty print it
    for acc, cl in ranking:
        print(f"Class {cl}: accuracy: {acc:.2f}")