如何使用sklearn.tree输出进行预测的节点?

时间:2015-01-02 11:30:01

标签: python scikit-learn regression

我想获取使用sklearn.tree进行预测的节点的所有信息。

例如:

from sklearn.datasets import load_iris
nfrom sklearn.tree import DecisionTreeClassifier
clf = DecisionTreeClassifier()
iris = load_iris()
clf = clf.fit(iris.data, iris.target)

现在我们可以使用以下方法预测班级:

clf.predict(iris.data[0, :])

如何获得进行预测的叶节点以及存储在叶子中的信息?

我知道上面例子的树的图形表示如下:

http://scikit-learn.org/stable/modules/tree.html#tree-classification

所以我知道对应于输入 iris.data [0,:] (第一个左子)的节点具有以下统计信息:

  • 误差= 0
  • 样品= 50
  • value = [50 0 0]

是否可以在不打印树的情况下自动获取输出节点和(上面)信息?根据我目前的不足,关键是获取进行预测的叶节点的 ID ,相关的统计数据随后包含在 clf.tree_.value [ID] clf.tree_.n_samples [ID]

谢谢

1 个答案:

答案 0 :(得分:2)

查看this问题。它说如何获得叶子的ID。然后,您可以使用clf.tree_.valueclf.tree.n_samples