我在3类数据集上使用了scikit-learn DecissionTreeClassifier。在我拟合分类器后,我访问tree_属性上的所有叶节点,以获得最终在每个类的给定节点中的实例数量。
clf = tree.DecisionTreeClassifier(max_depth=5)
clf.fit(X, y)
# lets assume there is a leaf node with id 5
print clf.tree_.value[5]
这将打印出来:
>>> array([[ 0., 1., 68.]])
但......我怎么知道该数组中哪个位置属于哪个类? 分类器有一个classes_属性,它也是一个列表
>>> clf.classes_
array(['CLASS_1', 'CLASS_2', 'CLASS_3'], dtype=object)
值数组上的索引1可能与类数组的索引1上的类匹配,依此类推?
答案 0 :(得分:7)
在scikit-learn邮件列表中询问此问题,我的猜测是正确的。结果是,value数组上的索引1与classes数组的索引1上的类匹配,依此类推
答案 1 :(得分:0)
不,它不是clf.classes_,而是包含X的列索引的clf.tree_.feature。如果X是Pandas DataFrame,则X.columns包含名称。您可以在a similar question中找到更详细的信息。