将scikit-learn DecisionTreeClassifier.tree_.value映射到预测类

时间:2014-10-05 21:33:42

标签: python scikit-learn decision-tree

我在3类数据集上使用了scikit-learn DecissionTreeClassifier。在我拟合分类器后,我访问tree_属性上的所有叶节点,以获得最终在每个类的给定节点中的实例数量。

clf = tree.DecisionTreeClassifier(max_depth=5)
clf.fit(X, y)
# lets assume there is a leaf node with id 5
print clf.tree_.value[5]

这将打印出来:

>>> array([[  0.,   1.,  68.]])

但......我怎么知道该数组中哪个位置属于哪个类? 分类器有一个classes_属性,它也是一个列表

>>> clf.classes_
array(['CLASS_1', 'CLASS_2', 'CLASS_3'], dtype=object)

值数组上的索引1可能与类数组的索引1上的类匹配,依此类推?

2 个答案:

答案 0 :(得分:7)

在scikit-learn邮件列表中询问此问题,我的猜测是正确的。结果是,value数组上的索引1与classes数组的索引1上的类匹配,依此类推

答案 1 :(得分:0)

不,它不是clf.classes_,而是包含X的列索引的clf.tree_.feature。如果X是Pandas DataFrame,则X.columns包含名称。您可以在a similar question中找到更详细的信息。