How do I get the row/column labels of a confusion matrix from scikit-learn?

Asked: 2019-02-14 02:41:21

Tags: python-3.x scikit-learn

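If I did not specify the labels when creating the matrix, how can I confirm which label each row and column of the resulting confusion matrix corresponds to, as in the following code: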

    from sklearn.metrics import confusion_matrix

    y_true = ["cat", "ant", "cat", "cat", "ant", "bird"]
    y_pred = ["ant", "ant", "cat", "cat", "ant", "cat"]
    cm = confusion_matrix(y_true, y_pred)

    array([[2, 0, 0],
           [0, 0, 1],
           [1, 0, 2]])

The documentation says "If none is given, those that appear at least once in y_true or y_pred are used in sorted order", so I assume the rows/columns are ("ant", "bird", "cat"), but how can I confirm this? I tried something like cm.labels, but that did not work.

1 Answer:

Answer 0 (score: 1)

From the source code of confusion_matrix:

if labels is None:
    labels = unique_labels(y_true, y_pred)

What is unique_labels, and where is it imported from?

from sklearn.utils.multiclass import unique_labels
unique_labels(y_true, y_pred)

returns

array(['ant', 'bird', 'cat'],
      dtype='<U4')

unique_labels extracts an ordered array of unique labels.
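To tie this back to the question, here is a minimal sketch of one way to confirm the order. It recovers the labels with unique_labels, passes them explicitly through the `labels` parameter of confusion_matrix, and checks that the result matches the default output. The variable names `cm_default`, `cm_explicit`, `same_order` and `rows_by_label` are made up for illustration.

```python
import numpy as np
from sklearn.metrics import confusion_matrix
from sklearn.utils.multiclass import unique_labels

y_true = ["cat", "ant", "cat", "cat", "ant", "bird"]
y_pred = ["ant", "ant", "cat", "cat", "ant", "cat"]

# Recover the label order that confusion_matrix falls back to
# when no labels argument is given.
labels = unique_labels(y_true, y_pred)

# Passing the same labels explicitly should give an identical matrix,
# which confirms the row/column order of the default call.
cm_default = confusion_matrix(y_true, y_pred)
cm_explicit = confusion_matrix(y_true, y_pred, labels=labels)
same_order = np.array_equal(cm_default, cm_explicit)

# Pair each row with its label (rows are true labels, columns are predictions).
rows_by_label = {str(label): row.tolist() for label, row in zip(labels, cm_default)}
print(same_order)
print(rows_by_label)
```

If you want the order to be unambiguous in the first place, passing `labels=` yourself is probably the simplest option, since you then control the row/column order directly.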

Some examples of unique_labels:

>>> from sklearn.utils.multiclass import unique_labels
>>> unique_labels([3, 5, 5, 5, 7, 7])
array([3, 5, 7])
>>> unique_labels([1, 2, 3, 4], [2, 2, 3, 4])
array([1, 2, 3, 4])
>>> unique_labels([1, 2, 10], [5, 11])
array([ 1,  2,  5, 10, 11])

Perhaps a more intuitive example:

unique_labels(['z', 'x', 'y'], ['a', 'z', 'c'], ['e', 'd', 'y'])

returns:

array(['a', 'c', 'd', 'e', 'x', 'y', 'z'],
      dtype='<U1')
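For intuition, the "sorted union of labels" behaviour can be imitated in plain Python. The sketch below is only an illustration of the idea and is not scikit-learn's implementation. It rebuilds the matrix from the question by hand so you can see which label each row and column stands for. The names `pair_counts` and `matrix` are hypothetical.

```python
from collections import Counter

y_true = ["cat", "ant", "cat", "cat", "ant", "bird"]
y_pred = ["ant", "ant", "cat", "cat", "ant", "cat"]

# Sorted union of every label that appears in either list.
labels = sorted(set(y_true) | set(y_pred))

# Count each (true, predicted) pair; missing pairs count as 0.
pair_counts = Counter(zip(y_true, y_pred))

# Rows follow the true labels, columns follow the predicted labels,
# both in the sorted label order.
matrix = [[pair_counts[(t, p)] for p in labels] for t in labels]

print(labels)
for label, row in zip(labels, matrix):
    print(label, row)
```

The rows printed here match the array shown in the question, which is consistent with the rows/columns being ordered as ("ant", "bird", "cat").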