partykit中节点的深度

时间:2016-02-08 08:53:49

标签: r binary-tree party

我正在使用partykit R包构建树,我想知道是否有一种简单有效的方法来确定每个内部节点的深度数。例如,根节点的深度为0,前两个kid节点的深度为1,下一个kid节点的深度为2,依此类推。这最终将用于计算变量的最小深度。以下是一个非常基本的示例(取自vignette("constparty", package="partykit")):

library("partykit")
library("rpart")
data("Titanic", package = "datasets")
ttnc<-as.data.frame(Titanic)
ttnc <- ttnc[rep(1:nrow(ttnc), ttnc$Freq), 1:4]
names(ttnc)[2] <- "Gender"
rp <- rpart(Survived ~ ., data = ttnc)
ttncTree<-as.party(rp)
plot(ttncTree)

#This is one of my many attempts which does NOT work
internalNodes<-nodeids(ttncTree)[-nodeids(ttncTree, terminal = TRUE)]
depth(ttncTree)-unlist(nodeapply(ttncTree, ids=internalNodes, FUN=function(n){depth(n)}))

在这个例子中,我想输出类似于:

的内容
nodeid = 1 2 4 7 
depth  = 0 1 2 1

如果我的问题太具体,我道歉。

2 个答案:

答案 0 :(得分:3)

这是一个可能的解决方案,应该足够有效,因为树通常不超过几十个节点。 我忽略了节点#1,因为它始终为0,因此无论是计算它还是显示它都没有意义(IMO)

Inters <- nodeids(ttncTree)[-nodeids(ttncTree, terminal = TRUE)][-1]
table(unlist(sapply(Inters, function(x) intersect(Inters, nodeids(ttncTree, from = x)))))
# 2 4 7 
# 1 2 1 

答案 1 :(得分:0)

我最近不得不重新考虑这个问题。下面是确定每个节点深度的函数。我根据运行|函数的垂直线print.party()出现的次数来计算深度。

library(stringr)
idDepth <- function(tree) {
  outTree <- capture.output(tree)
  idCount <- 1
  depthValues <- rep(NA, length(tree))
  names(depthValues) <- 1:length(tree)
  for (index in seq_along(outTree)){
    if (grepl("\\[[0-9]+\\]", outTree[index])) {
      depthValues[idCount] <- str_count(outTree[index], "\\|")
      idCount = idCount + 1
    }
  }
  return(depthValues)
}

> idDepth(ttncTree)
1 2 3 4 5 6 7 8 9 
0 1 2 2 3 3 1 2 2

肯定有一个更简单,更快速的解决方案,但这比使用intersect()函数要快。下面是一棵大树(大约1,500个节点)的计算时间的示例

# Compare computation time for large tree #
library(mlbench)
set.seed(470174)
dat <- data.frame(mlbench.friedman1(5000))
rp <- rpart(as.formula(paste0("y ~ ", paste(paste0("x.", 1:10), collapse=" + "))),
            data=dat, control = rpart.control(cp = -1, minsplit=3, maxdepth = 10))
partyTree <- as.party(rp)

> length(partyTree) #Number of splits
[1] 1503
> 
> # Intersect() computation time
> Inters <- nodeids(partyTree)[-nodeids(partyTree, terminal = TRUE)][-1]
> system.time(table(unlist(sapply(Inters, function(x) intersect(Inters, nodeids(partyTree, from = x))))))
   user  system elapsed 
  22.38    0.00   22.44 
> 
> # Proposed computation time
> system.time(idDepth(partyTree))
   user  system elapsed 
   2.38    0.00    2.38