遍历二叉树以获得拆分条件 - ctree(聚会),递归函数

时间:2016-02-11 05:01:09

标签: r binary-tree party

我正在尝试使用常规数据集重现我使用数据集获得的错误。如果我错过了什么,请纠正我。

使用Classification拟合library(party)树后,我试图在每个节点上获得树的分割条件。我设法写了一个代码,我相信它工作正常,直到我发现了一个错误。有人可以帮我解决吗?

我的代码:

 require(party)
 iris$Petal.Width <- as.factor(iris$Petal.Width)#imp to convert to factorial
 (ct <- ctree(Species ~ ., data = iris))
 plot(ct)
 #print(ct)
  a <- ct #convert it to s4 object
  t <- a@tree

 #recursive function to traverse the tree and get the splitting conditions

 recurse_tree <- function(tree,ret_list=list(),sub_list=list()){
   if(!tree$terminal){
     sub_list$assign <-list(tree$psplit$splitpoint,tree$psplit$variableName,class(tree$psplit))
     names(sub_list)[which(names(sub_list)=="assign")] <- paste("node",tree$nodeID,sep="")
     ret_list <- recurse_tree(tree$left, ret_list, sub_list)
     ret_list <- recurse_tree(tree$right, ret_list, sub_list)

  }
  if(tree$terminal){
    ret_list$assign <- c(sub_list, tree$prediction)
    names(ret_list)[which(names(ret_list)=="assign")] <- paste("node",tree$nodeID,sep="")
    return(ret_list)
   }
   return(ret_list)
  }

  result <- recurse_tree(t) #call to the functions

现在,结果给出了所有节点和拆分条件和预测的列表(我假设)。但是,当我检查

的分割条件时
  • Node5上的预期输出:{1.1,1.2,1.6,1.7} # from printing the tree print(ct), I get this
  • 输出我从我的函数接收 Node5:{“1”,“1.3”,“1.4”,“1.5”}这基本上是Node6的分割条件,这是错误。我怎么得到这个?

    z <- result[2] #I know node5 is second in the list

    z <- unlist(z,recursive = F,use.names = T) #unlist
    levels(z[[3]][[1]]) [which((z[[3]][[1]])==0)] #to find levels of corresponding values
    

plot for more details

我怀疑,我的函数(recurse_tree)总是给我正确的终端节点而不是左节点的分割条件。任何帮助将不胜感激。

0 个答案:

没有答案