混淆矩阵不支持多标签指示符

时间:2017-10-26 12:09:11

标签: python numpy scikit-learn classification

multilabel-indicator is not supported是我在尝试运行时收到的错误消息:

confusion_matrix(y_test, predictions)

y_test是一个DataFrame形状:

Horse | Dog | Cat
1       0     0
0       1     0
0       1     0
...     ...   ...

predictionsnumpy array

[[1, 0, 0],
 [0, 1, 0],
 [0, 1, 0]]

我已经搜索了一些错误消息,但我还没找到可以申请的内容。任何提示?

3 个答案:

答案 0 :(得分:29)

不,您对confusion_matrix的输入必须是预测列表,而不是OHE(一个热编码)。在argmaxy_test上致电y_pred,您应该得到您期望的结果。

confusion_matrix(
    y_test.values.argmax(axis=1), predictions.argmax(axis=1))

array([[1, 0],
       [0, 2]])

答案 1 :(得分:6)

混淆矩阵采用标签矢量(不是单热编码)。你应该运行

confusion_matrix(y_test.values.argmax(axis=1), predictions.argmax(axis=1))

答案 2 :(得分:0)

from sklearn.metrics import confusion_matrix

predictions_one_hot = model.predict(test_data)
cm = confusion_matrix(labels_one_hot.argmax(axis=1), predictions_one_hot.argmax(axis=1))
print(cm)

输出将如下所示:

[[298   2  47  15  77   3  49]
 [ 14  31   2   0   5   1   2]
 [ 64   5 262  22  94  38  43]
 [ 16   1  20 779  15  14  34]
 [ 49   0  71  33 316   7 118]
 [ 14   0  42  23   5 323   9]
 [ 20   1  27  32  97  13 436]]