从xgboost.dump中找到二叉树的所有路径

时间:2017-03-22 03:22:37

标签: python decision-tree xgboost

我有许多树的xgboost.dump文本文件。 我想找到为每条路径获取价值的所有路径。 这是一棵树。

tree[0]:
0:[a<0.966398] yes=1,no=2,missing=1
    1:[b<0.323071] yes=3,no=4,missing=3
        3:[c<0.461248] yes=7,no=8,missing=7
            7:leaf=0.00972768
            8:leaf=-0.0179376
        4:[a<0.379082] yes=9,no=10,missing=9
            9:leaf=0.0146003
            10:leaf=0.0454369
    2:[b<0.322352] yes=5,no=6,missing=5
        5:[c<0.674868] yes=11,no=12,missing=11
            11:leaf=0.0497964
            12:leaf=0.00953781
        6:[f<0.598267] yes=13,no=14,missing=13
            13:leaf=0.0504545
            14:leaf=0.0867654

我想将所有路径转换为

path1, a<0.966398, b<0.323071, c<0.461248, leaf = 0.00097268
path2, a<0.966398, b<0.323071, c>0.461248, leaf = -0.0179376
path3, a<0.966398, b>0.323071, a<0.379082, leaf = 0.0146003
path4, a<0.966398, b>0.323071, a>0.379082, leaf = 0.0454369
path5, a>0.966398, b<0.322352, c<0.674868, leaf = 0.0497964
path6, a>0.966398, b<0.322352, c>0.674868, leaf = 0.00953781
path7, a>0.966398, b>0.322352, f<0.598267, leaf = 0.0504545
path8, a>0.966398, b>0.322352, f>0.598267, leaf = 0.0864654

我已经尝试列出所有可能的路径,如

array([[ 0,  1,  3,  7],
       [ 0,  1,  3,  8],
       [ 0,  1,  4,  9],
       [ 0,  1,  4, 10],
       [ 0,  2,  5, 11],
       [ 0,  2,  5, 12],
       [ 0,  2,  6, 13],
       [ 0,  2,  6, 14]])

但是一旦max_depth更高,这种方式会导致错误,某些分支将停止增长并且路径将是错误的。 所以我需要在文本文件中解析yes,no来生成真正的,正确的路径。 有什么建议? 谢谢!

1 个答案:

答案 0 :(得分:1)

以下是我使用R实现解决此问题的方法。其他语言的用户可以遵循逻辑并以实物形式复制。

首先,我开始使用xgb.model.dt.tree()生成的模型转储文件。

然后,我编写了一个函数来解析从转储模型的单个树中的任意节点到最终父节点的有效路径。

稍后,我使用purrr :: by_row()将此函数应用于模型转储中的所有终端节点“Leaf”记录,并将结果转换为目的。

此函数有两个参数,一个用于正在测试的树,另一个用于终端节点的标识。它遵循以下一般步骤:

  1. 从基于每个树的目标(终端)节点开始,在c(“是”,“否”,“缺失”)决策拆分中找到具有目标节点作为有效子节点的行。
  2. 将此有效父节点ID连接到一个向量中,该向量将用于跟踪从目标节点到最终父节点的路径的每个步骤。该函数在函数完成时返回。
  3. 接下来,为链上的每个节点重复“谁是我的父”步骤,直到路径到达最终父节点(此节点ID始终以“-0”结尾),同时更新每个新步骤的路径向量链条。
  4. 一旦功能命中终端节点,返回()路径。
  5. 在我的例子中,我使用purrr :: by_row()将此函数应用于模型转储中的所有“Leaf”节点,而.collat​​ing =“rows”将此路径表示为输出中的附加行。

    这也很可能不是最快的方法。

    xgb.booster模型中nrounds或max_depth的增加将导致此过程的运行时间增加。您可以使用树的子集(xgb.model.dt.tree()的参数n_first_tree = N)开发方法,以便估计在最终模型中分析整个终端节点路径所需的时间。在我的例子中,max_depth = 5的约500棵树的模型可能需要30分钟。