如何获取scikit-learn决策树的所有节点的pos / neg实例计数?

时间:2015-08-13 14:46:27

标签: python scikit-learn decision-tree

我已经训练了一个sklearn决策树。

from sklearn.tree import DecisionTreeClassifier
c=DecisionTreeClassifier(class_weight="auto")
c.fit([[0,0],
       [0,1],
       [1,1],
      ],[0,1,0])

现在,我想检查每个节点有多少正/负样本。因此像

这样的图表
  counts: [2,1]            labels: (010)
                                 split by x0
    [1,1]       [1,0]         (01)        (0)
                           split by x1
 [1,0] [0,1]      0        (0)   (1)
   0     1

如何从训练有素的决策树中获得此(左计数)?

我可以看到一个c.tree_变量,但内容似乎不是很有帮助。有零,权重......而且很难猜测如何获得计数。

1 个答案:

答案 0 :(得分:1)

每个类的样本数存储在tree_.value中,但它只存储叶的节点值,因此我使用后序遍历来获取所有节点的值。

import numpy as np

def get_value(dt):
    left = dt.tree_.children_left
    right = dt.tree_.children_right
    value = dt.tree_.value
    leaves = np.argwhere(left == -1)[:, 0]

    def visit(node):
        if node in leaves:
            return
        visit(left[node])
        visit(right[node])
        value[node, :] = value[left[node], :] + value[right[node], :]

    visit(0)
    return value

例如,

from sklearn.tree import DecisionTreeClassifier
dt = DecisionTreeClassifier()
dt.fit([[0,0],
        [0,1],
        [1,1]], [0,1,0])
get_value(dt)

输出:

[[[ 2.  1.]]

 [[ 1.  1.]]

 [[ 1.  0.]]

 [[ 0.  1.]]

 [[ 1.  0.]]]

更新#1

我想知道为什么tree_.value只存储叶节点的值,然后我找到了https://stackoverflow.com/questions/27417809/show-values-at-each-node-level-of-scikit-learn-decision-treethis issue

事实证明,在scikit-learn 0.17.dev0中,tree_.value已经返回所有节点的值。

In [1]: from sklearn.tree import DecisionTreeClassifier

In [2]: dt = DecisionTreeClassifier()

In [3]: dt.fit([[0,0],
   ...:         [0,1],
   ...:         [1,1]], [0,1,0])
Out[3]:
DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
            max_features=None, max_leaf_nodes=None, min_samples_leaf=1,
            min_samples_split=2, min_weight_fraction_leaf=0.0,
            random_state=None, splitter='best')

In [4]: dt.tree_.value
Out[4]:
array([[[ 2.,  1.]],

       [[ 1.,  1.]],

       [[ 1.,  0.]],

       [[ 0.,  1.]],

       [[ 1.,  0.]]])

更新#2

虽然我认为对"取消加权"当给出class_weight时,可以实现这一点。

class_weight

计算
In [1]: from sklearn.utils import compute_class_weight

In [2]: compute_class_weight('auto', [0, 1], [0, 1, 0])
Out[2]: array([ 0.66666667,  1.33333333])

因此,您可以在value[node, :] /= class_weight之后添加if node in leaves:以重新计算叶节点的值。