我正在研究一个多类分类问题,在那里我有很多不同的类(50多个)。
问题是,我想突出显示最差的预测类别(例如,在混淆矩阵等中),以便对我的分类器进行进一步的调整。
我的预测和测试数据保存在一个列表中(来自sklearn的一个小示例):
y_true = [2, 0, 2, 2, 0, 1]
y_pred = [0, 0, 2, 2, 0, 2]
confusion_matrix(y_true, y_pred)
array([[2, 0, 0],
[0, 0, 1],
[1, 0, 2]])
在该示例中,如何从矩阵中获得1类?那里的预测是完全错误的。有没有一种方法可以根据他们的真实正面预测对他们进行分类?
答案 0 :(得分:1)
您可以使用scikit-learn中的classifiction_report,它将返回具有精度,召回率和F分数的字典。然后,您可以按排序方式打印字典,以便轻松查看最差的预测类别。
#prints classification_report
print(classification_report(y_true, y_pred)
#returns a dict, which you can easily sort by prediction
report = classification_report(y_true, y_pred, output_dict=True)
答案 1 :(得分:1)
您可以为此使用简单的功能:
def print_class_accuracies(confusion_matrix):
# get the number of occurrences for each class
counts = {cl: y_true.count(cl) for cl in set(y_true)}
# extract the diagonal values (true positives)
tps = dict(enumerate(conf.diagonal()))
# Get the accuracy for each class, preventing ZeroDivisionErrors
pred_accuracy = {cl: tps[cl]/counts.get(cl, 1) for cl in tps}
# Get a ranking, worst accuracies are first/lowest
ranking = sorted([(acc,cl) for cl, acc in pred_accuracy.items()])
# Pretty print it
for acc, cl in ranking:
print(f"Class {cl}: accuracy: {acc:.2f}")