如何解释Scikit-learn混淆矩阵

时间:2014-04-25 17:03:14

标签: python machine-learning scikit-learn

我正在使用confusion matrix来检查分类器的性能。

我正在使用Scikit-Learn,我有点困惑。我怎样才能解释

的结果
from sklearn.metrics import confusion_matrix
>>> y_true = [2, 0, 2, 2, 0, 1]
>>> y_pred = [0, 0, 2, 2, 0, 2]
>>> confusion_matrix(y_true, y_pred)
array([[2, 0, 0],
       [0, 0, 1],
       [1, 0, 2]])

我如何判断这些预测值是好还是不好。

1 个答案:

答案 0 :(得分:1)

判断分类器好坏的最简单方法就是使用一些标准错误度量(例如Mean squared error)来计算错误。我想你的例子是从Scikit的documentation复制的,所以我假设你已经阅读了这个定义。

我们在这里有三个课程:012。在对角线上,混淆矩阵告诉您,特定类的预测频率是多少。所以从对角线2 0 2我们可以说索引0的类被正确分类了2次,索引1的类从未被正确预测过,而索引为2的类是正确预测了2次。

在对角线下方和上方,您有数字告诉您索引等于元素行数的类被分类为索引等于矩阵列的类的次数。例如,如果您查看第一列,则在对角线下方:0 1(在矩阵的左下角)。较低的1告诉您,索引2(最后一行)的类曾被错误地归类为0(第一列)。这相当于您的y_true中有一个标签为2且被归类为0的样本。这发生在第一个样本中。

如果您对混淆矩阵中的所有数字求和,则得到测试样本的数量(2 + 2 + 1 + 1 = 6 - 等于y_truey_pred的长度。如果您对行进行求和,则会获得每个标签的样本数:正如您可以验证的那样,0中确实存在两个1,一个2和三个y_pred }。

例如,如果你用这个数字划分矩阵元素,你可以说,例如,标签2的类被正确识别,准确率达到66%,而在1/3的情况下,它被混淆了(因此名称)带有标签为0的类。

<强> TL; DR:

虽然单数误差测量可以衡量整体性能,但是使用混淆矩阵可以确定是否(某些例子):

  • 你的分类器只是糟透了一切

  • 或者它可以很好地处理某些类,而有些类则没有(这会给你一个提示来查看数据的这个特定部分并观察分类器对这些情况的行为)

  • 它做得很好,但经常混淆标签A和B.例如,对于线性分类器,您可能需要检查,如果这些类是线性可分的。