在 scikit-learn 决策树中,如何识别导致错误分类的决策?

时间:2021-03-16 13:04:09

标签: python machine-learning scikit-learn decision-tree

我正在尝试为数据集创建决策树并研究由此产生的混淆矩阵。虽然混淆矩阵告诉我发生了多少错误分类,但它并不能准确告诉我 X_train 中的哪些特定实例被错误分类。我试图找出哪些是这些错误分类的实例以及它们最终在哪个叶节点中。我知道我可以使用 decision_path() 但它没有告诉我该特定实例是否被错误分类。我在这里的主要目标是确定混淆和错误分类的实例在哪里结束。以下是我的代码:

from sklearn.datasets import load_iris
iris=load_iris()

Y_train=iris.target
X_train=iris.data

clf=tree.DecisionTreeClassifier( max_depth=3, criterion='entropy')
clf.fit(X_train, Y_train)
pred=clf.predict(X_train)
print('Accuracy on test data is %.2f' % (accuracy_score(Y_train, pred)))

1 个答案:

答案 0 :(得分:1)

您在 pred 中获得了所有预测,在 Y_train 中获得了所有训练值

您错误分类的预测结果就是 pred[pred!=Y_train]

如果您想要功能 X_train[pred!=Y_train]