pythom中多类随机森林的打印混淆矩阵

时间:2020-07-28 13:52:17

标签: python pandas machine-learning scikit-learn random-forest

我正在使用sklearn.RandomForestClassifier,并且我有11个课程。我的数据在数据框中,所有变量都经过热编码。这些类是字符串,例如“土豆”,“西红柿”,“ Straberry”等。

当我尝试打印混淆矩阵时,得到以下信息:

print(pd.crosstab(y_test, y_pred))

Error: If using all scalar values, you must pass an index

当我传递索引时:

print(pd.crosstab(y_test, y_pred, index = [0]))

Error:crosstab() got multiple values for argument 'index'

解决这个问题的最佳方法是什么?

1 个答案:

答案 0 :(得分:0)

该错误表明您需要将参数“ index”传递给交叉表,而不是一个可以帮助您遍历列表的索引。您可以找到正确的方法以及有关它的更多详细信息here

您还可以使用以下代码在Sci-Kit Learn中绘制混淆矩阵

此代码从用于混淆矩阵的训练数据中获取所有标签

label=y_train.unique()
label=np.sort(label)

此代码导入混淆矩阵并将其绘制。 plt.cm.Blues用于配色方案,clf是您的分类器,请确保使用您命名的分类器对其进行更改。

from sklearn.metrics import plot_confusion_matrix
cm=plot_confusion_matrix(clf,X_test, y_test,labels=label,cmap=plt.cm.Blues)