如何从sklearn决策树收集所有路径?

时间:2019-05-16 13:01:06

标签: python scikit-learn binary-tree treepath

我正在尝试从Skealrn中的决策树生成所有路径。 这里的estimator来自随机森林,它是sklearn中的决策树。 但是我对sklearn决策树的数据结构感到困惑。似乎leftright这里包含了所有的左节点。

当我尝试打印路径时,它可以正常工作。

def get_code(tree, feature_names, tabdepth=0):
    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    features  = [feature_names[i] for i in tree.tree_.feature]
    value = tree.tree_.value

    def recurse(left, right, threshold, features, node, tabdepth=0):
            if (threshold[node] != -2):
                print ('\t' * tabdepth + "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
                if left[node] != -1:
                        recurse (left, right, threshold, features,left[node], tabdepth+1)
                print ('\t' * tabdepth + "} else {")
                if right[node] != -1:
                        recurse (left, right, threshold, features,right[node], tabdepth+1)
                print ('\t' * tabdepth + "}")
            else:
                print ('\t' * tabdepth + "return " + str(value[node]))

    recurse(left, right, threshold, features, 0)

但是我需要收集列表中的所有路径,如果叶子节点为“正常”,也就不记录路径,所以我尝试了以下代码:

def extract_attack_rules(estimator, feature_names):
    left      = estimator.tree_.children_left
    right     = estimator.tree_.children_right
    threshold = estimator.tree_.threshold
    features  = [feature_names[i] for i in estimator.tree_.feature]
    value = estimator.tree_.value

    def recurse(left, right, threshold, features, node):
        path_lst = []

        if threshold[node] != -2:  # not leaf node
            left_cond = features[node]+"<="+str(threshold[node])
            right_cond = features[node]+">"+str(threshold[node])

            if left[node] != -1:  # not leaf node
                left_path_lst = recurse(left, right, threshold, features,left[node])
            if right[node] != -1:  # not leaf node
                right_path_lst = recurse(left, right, threshold, features,right[node])

            if left_path_lst is not None:
                path_lst.extend([left_path.append(left_cond) for left_path in left_path_lst])

            if pre_right_path is not None:
                path_lst.extend([right_path.append(right_cond) for right_path in right_path_lst])
            return path_lst

        else:  # leaf node, the attack type
            if value[node][0][0] > 0:  # if leaf is normal, not collect this path
                return None
            else:  # attack
                for i in range(len(value[node][0])):
                    if value[node][0][i] > 0:
                        return [[value[node][0][i]]]

    all_path = recurse(left, right, threshold, features, 0)

    return all_path

它返回一个超级巨无霸的结果,那就是没有足够的内存要加载,我敢肯定这里的代码有问题,因为所有需要的路径都不应该那么大。我也在这里尝试了方法:Getting decision path to a node in sklearn,但是sklearn树结构的输出只会让我更加困惑。

您知道如何在这里解决问题吗?

0 个答案:

没有答案