我想获取使用sklearn.tree进行预测的节点的所有信息。
例如:
from sklearn.datasets import load_iris
nfrom sklearn.tree import DecisionTreeClassifier
clf = DecisionTreeClassifier()
iris = load_iris()
clf = clf.fit(iris.data, iris.target)
现在我们可以使用以下方法预测班级:
clf.predict(iris.data[0, :])
如何获得进行预测的叶节点以及存储在叶子中的信息?
我知道上面例子的树的图形表示如下:
http://scikit-learn.org/stable/modules/tree.html#tree-classification
所以我知道对应于输入 iris.data [0,:] (第一个左子)的节点具有以下统计信息:
是否可以在不打印树的情况下自动获取输出节点和(上面)信息?根据我目前的不足,关键是获取进行预测的叶节点的 ID ,相关的统计数据随后包含在 clf.tree_.value [ID] 中 clf.tree_.n_samples [ID]
谢谢