提取每个终端节点的路径

时间:2017-06-18 22:24:56

标签: python json nested xgboost

我有一个python嵌套字典结构,如下所示。  这是一个小例子,但我有更大的例子,可以有不同的嵌套级别。

由此,我需要提取一个列表:

  1. 每个终端“叶子”节点的一条记录
  2. 表示通向该节点的逻辑路径的字符串,列表或对象
    • (例如'nodeid_3:X< 0.500007和X< 0.279907')
  3. 本周末我花了大部分时间试图让事情变得有效,并且意识到我对递归有多么糟糕。

    columns

1 个答案:

答案 0 :(得分:1)

您的数据结构是递归的。如果某个节点有 children 键,那么我们可以认为它不是终端。

要分析数据,您需要一个跟踪祖先的递归函数(路径)。

我会这样实现:

def find_path(obj, path=None):
    path = path or []
    if 'children' in obj:
        child_obj = {k: v for k, v in obj.items()
                     if k in ['nodeid', 'split_condition']}
        child_path = path + [child_obj]
        children = obj['children']
        for child in children:
            find_path(child, child_path)
    else:
        pprint.pprint((obj, path))

如果你打电话:

find_path(data)

你得到3个结果:

({'cover': 2291, 'leaf': -0.0611795, 'nodeid': 3},
 [{'nodeid': 0, 'split_condition': 0.500007},
  {'nodeid': 1, 'split_condition': 0.279907}])
({'cover': 1779, 'leaf': -0.00965727, 'nodeid': 4},
 [{'nodeid': 0, 'split_condition': 0.500007},
  {'nodeid': 1, 'split_condition': 0.279907}])
({'cover': 3930, 'leaf': -0.0611946, 'nodeid': 2},
 [{'nodeid': 0, 'split_condition': 0.500007}])

当然,您可以将pprint.pprint()的调用替换为yield,以将此功能转换为生成器:

def iter_path(obj, path=None):
    path = path or []
    if 'children' in obj:
        child_obj = {k: v for k, v in obj.items()
                     if k in ['nodeid', 'split_condition']}
        child_path = path + [child_obj]
        children = obj['children']
        for child in children:
            # for o, p in iteration_path(child, child_path):
            #     yield o, p
            yield from iter_path(child, child_path)
    else:
        yield obj, path

请注意递归调用的yield from用法。你可以像下面这样使用这个生成器:

for obj, path in iter_path(data):
    pprint.pprint((obj, path))

您还可以更改child_obj对象的构建方式以满足您的需求。

保持对象的顺序:反转if条件:if 'children' not in obj: …