sklearn GBDT中的叶子值是什么,我如何获得它们?

时间:2017-11-24 01:44:57

标签: tree scikit-learn boosting

我可以将GBDT的结构导出到the tree.export_graphviz function的图像:

```Python3

from sklearn.datasets import load_iris
from sklearn import tree
from sklearn.ensemble import GradientBoostingClassifier

clf = GradientBoostingClassifier(n_estimators=1) # set to 1 for the sake of simplicity
iris = load_iris()

clf = clf.fit(iris.data, iris.target)
tree.export_graphviz(clf.estimators_[0,0], out_file='tree.dot')
check_call(['dot','-Tpng','tree.dot','-o','tree.png'])

```

This is the obtained image.

我想知道叶子上的value是什么?我怎样才能获得它们?

我尝试了applydecision_function功能,但都无效。

1 个答案:

答案 0 :(得分:0)

您可以使用其内部对象tree_及其属性访问每个树的离开属性。 export_graphviz正好使用这种方法。

考虑这段代码。对于每个属性,它在所有树节点上提供其值的数组:

print(clf.estimators_[0,0].tree_.feature)
print(clf.estimators_[0,0].tree_.threshold)
print(clf.estimators_[0,0].tree_.children_left)
print(clf.estimators_[0,0].tree_.children_right)
print(clf.estimators_[0,0].tree_.n_node_samples)
print(clf.estimators_[0,0].tree_.value.ravel())

输出

[ 2 -2 -2]
[ 2.45000005 -2.         -2.        ]
[ 1 -1 -1]
[ 2 -1 -1]
[150  50 100]
[  3.51570624e-16   2.00000000e+00  -1.00000000e+00]

也就是说,您的树有3个节点,第一个节点将功能2的值与2.45等进行比较。

根节点,左侧和右侧叶子中的值分别为3e-162-1

这些值虽然不易解释,因为树试图预测GBDT损失函数的梯度。