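The decision tree in our current project uses the conditional inference tree (ctree) algorithm. For a binary ctree, I can extract which levels of the split variable go to each child node with the following code: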
```r
library(party)

# Fit a conditional inference tree (party::ctree)
prod_discount_data_ctree <- ctree(Discount ~ Prod, data = prod_discount_data,
                                  controls = ctree_control(minsplit = 30))
plot(prod_discount_data_ctree)

# Levels of the split variable at the root node
lvls <- levels(prod_discount_data_ctree@tree$psplit$splitpoint)
# Levels sent to the left child node (splitpoint == 1)
left.df = lvls[prod_discount_data_ctree@tree$psplit$splitpoint == 1]
# Levels sent to the right child node (splitpoint == 0)
right.df = lvls[prod_discount_data_ctree@tree$psplit$splitpoint == 0]
```
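This works fine when the tree has only one split node (depth = 1) that divides into two leaf nodes. But if the root node (node 1) splits into nodes 2 and 5, and those split further into leaf nodes (node 2 → {3, 4}, node 5 → {6, 7}), how do I traverse the deeper levels and get the split levels for each leaf node? For this example, I would like the split levels for nodes 3, 4, 6 and 7 as four lists.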
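A depth-independent alternative, offered only as a sketch: party's `where()` returns the terminal node ID of every training observation, so the levels can be grouped by leaf directly without walking the tree. The helper `levels_by_leaf` and the toy data below are hypothetical. One caveat: only levels that actually occur in the training data show up, so a level with no rows is not assigned to any leaf this way.

```r
# Group the levels of a factor predictor by the terminal node each row lands in.
# `leaf_ids` is assumed to come from party::where(<fitted ctree>).
levels_by_leaf <- function(x, leaf_ids) {
  # split() returns one element per distinct leaf ID, named by that ID
  lapply(split(as.character(x), leaf_ids), function(v) sort(unique(v)))
}

# Hypothetical data: six rows whose Prod levels landed in leaves 3, 4, 6 and 7
prod <- factor(c("A", "B", "A", "C", "D", "C"))
leaf <- c(3, 4, 3, 6, 7, 6)
str(levels_by_leaf(prod, leaf))
```

With the question's objects, the call would look like `levels_by_leaf(prod_discount_data$Prod, where(prod_discount_data_ctree))`.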
Answer 0 (score: 0)
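I tried every option I could think of and finally found a way to traverse the ctree and get the split levels for each leaf node. I am pasting the snippet here for future reference in case anyone needs it.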
```r
# Hard-coded for a depth-2 tree whose root's left and right children both split again
if (nrow(SubBrandright_total) > 200) {
  sec_discount_data <- subset(SubBrandright_total, select = c(Discount, Sector))
  sec_discount_data_ctree <- ctree(Discount ~ Sector, data = sec_discount_data,
                                   controls = ctree_control(minsplit = 30))
  sec_lvls_r <- levels(sec_discount_data_ctree@tree$psplit$splitpoint)

  # Check whether a node is terminal (TRUE) or not (FALSE)
  #print(sec_discount_data_ctree@tree$terminal)
  #print(sec_discount_data_ctree@tree$left$terminal)
  #print(sec_discount_data_ctree@tree$left$left$terminal)
  #print(sec_discount_data_ctree@tree$left$right$terminal)

  # Left subtree: levels sent left at the root, then left again at the left child
  sec_left_left.df = sec_lvls_r[sec_discount_data_ctree@tree$left$psplit$splitpoint == 1]
  sec_left.df = sec_lvls_r[sec_discount_data_ctree@tree$psplit$splitpoint == 1]
  # Left-right leaf = levels of the left child minus those of the left-left leaf
  sec_left_right.df = setdiff(sec_left.df, sec_left_left.df)
  print("Sector Segmentation")
  print(sec_left_left.df)
  print(sec_left_right.df)

  # Right subtree: levels sent right at the root, then right again at the right child
  sec_right.df = sec_lvls_r[sec_discount_data_ctree@tree$psplit$splitpoint == 0]
  # intersect() keeps only levels that actually reached the right child, in case
  # levels from the left subtree are also coded 0 there
  sec_right_right.df = intersect(sec_right.df,
                                 sec_lvls_r[sec_discount_data_ctree@tree$right$psplit$splitpoint == 0])
  # Right-left leaf = levels of the right child minus those of the right-right leaf
  sec_right_left.df = setdiff(sec_right.df, sec_right_right.df)
  print(sec_right_left.df)
  print(sec_right_right.df)
}
```
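The snippet above fixes the depth at two. A minimal recursive sketch of the same idea (levels coded 1 go to the left child, the remaining levels of that node go right) works at any depth. It assumes every split is on one unordered factor such as `Prod` or `Sector`, and relies on the `terminal`, `nodeID`, `psplit$splitpoint`, `left` and `right` elements of a party `@tree` node. The function name `leaf_levels` and the hand-built tree are hypothetical, for illustration only.

```r
# Recursively collect the factor levels that reach each terminal node.
# Assumption: all splits are nominal splits on the same factor, so
# psplit$splitpoint is a 0/1 vector carrying a "levels" attribute.
leaf_levels <- function(node, current = NULL) {
  if (isTRUE(node$terminal)) {
    # Terminal node: return its level set, named by node ID
    return(setNames(list(current), node$nodeID))
  }
  sp <- node$psplit$splitpoint
  lvls <- levels(sp)
  # At the root, every level of the factor is still in play
  if (is.null(current)) current <- lvls
  # Levels coded 1 go left; intersect() drops levels that never reached this node
  go_left <- intersect(current, lvls[sp == 1])
  go_right <- setdiff(current, go_left)
  c(leaf_levels(node$left, go_left),
    leaf_levels(node$right, go_right))
}

# Hypothetical tree with the question's shape: 1 -> {2, 5}, 2 -> {3, 4}, 5 -> {6, 7}
lv <- c("A", "B", "C", "D")
mk_split <- function(bits) list(splitpoint = structure(bits, levels = lv))
mk_leaf <- function(id) list(nodeID = id, terminal = TRUE)
tree <- list(
  nodeID = 1, terminal = FALSE, psplit = mk_split(c(1L, 1L, 0L, 0L)),
  left = list(nodeID = 2, terminal = FALSE, psplit = mk_split(c(1L, 0L, 0L, 0L)),
              left = mk_leaf(3), right = mk_leaf(4)),
  right = list(nodeID = 5, terminal = FALSE, psplit = mk_split(c(0L, 0L, 1L, 0L)),
               left = mk_leaf(6), right = mk_leaf(7))
)
str(leaf_levels(tree))
```

On a fitted model the call would be `leaf_levels(prod_discount_data_ctree@tree)`, which is expected to return one list element per terminal node (nodes 3, 4, 6 and 7 in the question's example). Passing the node's level set down the recursion is what replaces the manual `setdiff()` bookkeeping at each depth.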