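I am building trees with the partykit R package, and I would like to know whether there is a simple, efficient way to determine the depth of each internal node. For example, the root node has depth 0, the first two kid nodes have depth 1, the next kid nodes have depth 2, and so on. This will eventually be used to compute the minimal depth of a variable. Below is a very basic example (taken from vignette("constparty", package = "partykit")):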
library("partykit")
library("rpart")
data("Titanic", package = "datasets")
ttnc<-as.data.frame(Titanic)
ttnc <- ttnc[rep(1:nrow(ttnc), ttnc$Freq), 1:4]
names(ttnc)[2] <- "Gender"
rp <- rpart(Survived ~ ., data = ttnc)
ttncTree<-as.party(rp)
plot(ttncTree)
#This is one of my many attempts which does NOT work
internalNodes<-nodeids(ttncTree)[-nodeids(ttncTree, terminal = TRUE)]
depth(ttncTree)-unlist(nodeapply(ttncTree, ids=internalNodes, FUN=function(n){depth(n)}))
In this example, I would like output similar to:
nodeid = 1 2 4 7
depth  = 0 1 2 1
I apologize if my question is too specific.
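Answer 0 (score: 3)
Here is a possible solution that should be efficient enough, since trees usually have no more than a few dozen nodes. I ignore node #1 because its depth is always 0, so there is no point in either computing or displaying it (IMO).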
Inters <- nodeids(ttncTree)[-nodeids(ttncTree, terminal = TRUE)][-1]
table(unlist(sapply(Inters, function(x) intersect(Inters, nodeids(ttncTree, from = x)))))
# 2 4 7
# 1 2 1
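Why the counting works: every non-root inner node x adds one to the tally of each non-root inner node in its own subtree, itself included. The final count for a node is therefore the number of its ancestors (root excluded) plus one, which is exactly its depth. Here is a minimal base R sketch of that argument. The parent vector is a hand-written assumption for illustration (a 9-node tree with preorder ids, chosen to be consistent with the depths reported elsewhere in this thread), and path_to_root() and subtree() are hypothetical helpers, not partykit functions.

```r
# Hypothetical parent vector for a 9-node tree with ids in preorder
# (assumed shape: 1 -> {2, 7}, 2 -> {3, 4}, 4 -> {5, 6}, 7 -> {8, 9}).
parent <- c(NA, 1L, 2L, 2L, 4L, 4L, 1L, 7L, 7L)

# Path from node v up to the root, v itself included.
path_to_root <- function(v) {
  out <- integer(0)
  while (!is.na(v)) {
    out <- c(out, v)
    v <- parent[v]
  }
  out
}

# All node ids in the subtree rooted at x
# (stands in for nodeids(tree, from = x)).
subtree <- function(x) {
  which(vapply(seq_along(parent),
               function(v) x %in% path_to_root(v),
               logical(1)))
}

inner  <- sort(unique(parent[!is.na(parent)]))  # nodes that have kids
inters <- setdiff(inner, 1L)                    # drop the root

# Same tally as in the answer: one vote per (ancestor-or-self, node) pair.
counts <- table(unlist(lapply(inters, function(x) intersect(inters, subtree(x)))))
print(counts)
```

On this toy tree the tally for nodes 2, 4 and 7 comes out as 1, 2 and 1, the same as the result shown above. Note that the approach calls the subtree lookup once per inner node, so the work grows roughly quadratically with tree size.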
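Answer 1 (score: 0)
I recently had to revisit this problem. Below is a function that determines the depth of every node. I compute the depth by counting how many times the vertical bar | appears on each node's line in the output of print.party().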
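library(stringr)
# Depth of every node, read off the printed tree: print.party() prefixes each
# node line with one "|" per level, so counting the bars gives the depth.
idDepth <- function(tree) {
  outTree <- capture.output(tree)
  idCount <- 1
  depthValues <- rep(NA, length(tree))
  names(depthValues) <- 1:length(tree)
  for (index in seq_along(outTree)) {
    # Node lines are the ones that carry a node id such as "[3]"
    if (grepl("\\[[0-9]+\\]", outTree[index])) {
      depthValues[idCount] <- str_count(outTree[index], "\\|")
      idCount <- idCount + 1
    }
  }
  return(depthValues)
}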
> idDepth(ttncTree)
1 2 3 4 5 6 7 8 9
0 1 2 2 3 3 1 2 2
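There is surely a simpler and faster solution, but this is faster than the intersect() approach. Below is a comparison of computation times on a large tree (about 1,500 nodes):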
# Compare computation time for large tree #
library(mlbench)
set.seed(470174)
dat <- data.frame(mlbench.friedman1(5000))
rp <- rpart(as.formula(paste0("y ~ ", paste(paste0("x.", 1:10), collapse=" + "))),
data=dat, control = rpart.control(cp = -1, minsplit=3, maxdepth = 10))
partyTree <- as.party(rp)
> length(partyTree) # Number of nodes
[1] 1503
>
> # Intersect() computation time
> Inters <- nodeids(partyTree)[-nodeids(partyTree, terminal = TRUE)][-1]
> system.time(table(unlist(sapply(Inters, function(x) intersect(Inters, nodeids(partyTree, from = x))))))
user system elapsed
22.38 0.00 22.44
>
> # Proposed computation time
> system.time(idDepth(partyTree))
user system elapsed
2.38 0.00 2.38
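A more direct alternative, offered as a sketch rather than a benchmarked solution: walk the node structure recursively and record the depth as you go, so each node is visited once and no printed output has to be parsed. It relies on partykit's node_party(), kids_node(), id_node() and is.terminal() accessors. The function name node_depths() is my own and not part of partykit.

```r
library("partykit")

# Hypothetical helper (not a partykit function): depth of every node,
# returned as an integer vector named by node id.
node_depths <- function(tree) {
  # Accept either a party object or a bare partynode.
  root <- if (inherits(tree, "party")) node_party(tree) else tree
  depths <- integer(0)

  walk <- function(node, d) {
    # Record this node's depth, then descend into its kids (if any).
    depths[as.character(id_node(node))] <<- d
    if (!is.terminal(node)) {
      for (kid in kids_node(node)) walk(kid, d + 1L)
    }
  }

  walk(root, 0L)
  depths[order(as.integer(names(depths)))]
}
```

With the objects from the question, node_depths(ttncTree)[as.character(internalNodes)] should give the depths of the inner nodes only, and node_depths(ttncTree) should agree with the idDepth(ttncTree) output shown above. I have not timed it on the large tree.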