所以我使用了rpart包创建了一个树模型,我找到了一个有趣的规则,并想知道是否有一种简单的方法可以看出该数据框中的哪些观察结果通过了该规则。
使用path.rpart查找它从树上取下的路径,并手动将这些过滤器输入到数据框中以查找它们似乎非常繁琐。有没有一种方法可以传递树和/或节点,以及数据帧并返回该帧中以该节点结束的所有元素?
答案 0 :(得分:9)
我修改了path.rpart
中的代码,以返回属于特定节点的数据子集,而不是返回有关该节点的信息。它可以通过点击绘图或通过path.rpart
函数传递节点来工作。这是代码
subset.rpart <- function (tree, df, nodes) {
if (!inherits(tree, "rpart"))
stop("Not a legitimate \"rpart\" object")
stopifnot(nrow(df)==length(tree$where))
frame <- tree$frame
n <- row.names(frame)
node <- as.numeric(n)
if (missing(nodes)) {
xy <- rpart:::rpartco(tree)
i <- identify(xy, n = 1L, plot = FALSE)
if(i> 0L) {
return( df[tree$where==i, ] )
} else {
return(df[0,])
}
}
else {
if (length(nodes <- rpart:::node.match(nodes, node)) == 0L)
return(df[0,])
return ( df[tree$where %in% as.numeric(nodes), ] )
}
}
我将在包
中的一些样本数据上使用它fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)
plot(fit)
text(fit)
然后在特定节点上找到观察结果,运行
subset.rpart(fit, kyphosis)
并单击绘图上的节点。完成后,将返回该节点上的所有观察结果。您必须使用用于建模的相同data.frame
才能正常工作。您也可以使用path.rpart
# path.rpart(fit)
# node number: 10 ---> looks interesting
# root
# Start>=8.5
# Start< 14.5
# Age< 55
subset.rpart(fit, kyphosis, 10)
# Kyphosis Age Number Start
# 14 absent 1 4 12
# 20 absent 27 4 9
# 26 absent 9 5 13
# 37 absent 1 3 9
# 39 absent 20 6 9
# 42 absent 35 3 13
# 57 absent 2 3 13
# 59 absent 51 7 9
# 66 absent 17 4 10
# 69 absent 18 4 11
# 78 absent 26 7 13
# 81 absent 36 4 13
答案 1 :(得分:0)
#' subset of rpart node: return logical index
#' @param tree rpart model
#' @param node which node/leaf?
#' @export
subset_rpart <- function (tree, node) {
nodes = as.numeric(rownames(tree$frame))
nodes = log(nodes, 2)
lower = log(node, 2)
upper = log(node + 1, 2)
a = floor(lower)
lower_ = lower - a
upper_ = upper - a
nodes_ = nodes %% 1
w = which(((nodes_ >= lower_ & nodes_ < upper_) | (nodes_ + 1 < upper_)) & nodes >= lower)
tree$where %in% w
}
#' subset df by subset_rpart
#' @param tree rpart model
#' @param node node number
#' @param df df
#' @export
subset.rpart = function(tree, node, df){
df[subset_rpart(tree, node), ]
}