multilabel-indicator is not supported
是我在尝试运行时收到的错误消息:
confusion_matrix(y_test, predictions)
y_test
是一个DataFrame
形状:
Horse | Dog | Cat
1 0 0
0 1 0
0 1 0
... ... ...
predictions
是numpy array
:
[[1, 0, 0],
[0, 1, 0],
[0, 1, 0]]
我已经搜索了一些错误消息,但我还没找到可以申请的内容。任何提示?
答案 0 :(得分:29)
不,您对confusion_matrix
的输入必须是预测列表,而不是OHE(一个热编码)。在argmax
和y_test
上致电y_pred
,您应该得到您期望的结果。
confusion_matrix(
y_test.values.argmax(axis=1), predictions.argmax(axis=1))
array([[1, 0],
[0, 2]])
答案 1 :(得分:6)
混淆矩阵采用标签矢量(不是单热编码)。你应该运行
confusion_matrix(y_test.values.argmax(axis=1), predictions.argmax(axis=1))
答案 2 :(得分:0)
from sklearn.metrics import confusion_matrix
predictions_one_hot = model.predict(test_data)
cm = confusion_matrix(labels_one_hot.argmax(axis=1), predictions_one_hot.argmax(axis=1))
print(cm)
输出将如下所示:
[[298 2 47 15 77 3 49]
[ 14 31 2 0 5 1 2]
[ 64 5 262 22 94 38 43]
[ 16 1 20 779 15 14 34]
[ 49 0 71 33 316 7 118]
[ 14 0 42 23 5 323 9]
[ 20 1 27 32 97 13 436]]