在数据框中查找为树模型中的节点传递规则的数据元素?

时间:2014-05-29 00:41:32

标签: r rpart

所以我使用了rpart包创建了一个树模型,我找到了一个有趣的规则,并想知道是否有一种简单的方法可以看出该数据框中的哪些观察结果通过了该规则。

使用path.rpart查找它从树上取下的路径,并手动将这些过滤器输入到数据框中以查找它们似乎非常繁琐。有没有一种方法可以传递树和/或节点,以及数据帧并返回该帧中以该节点结束的所有元素?

2 个答案:

答案 0 :(得分:9)

我修改了path.rpart中的代码,以返回属于特定节点的数据子集,而不是返回有关该节点的信息。它可以通过点击绘图或通过path.rpart函数传递节点来工作。这是代码

subset.rpart <- function (tree, df, nodes) {
    if (!inherits(tree, "rpart")) 
        stop("Not a legitimate \"rpart\" object")
    stopifnot(nrow(df)==length(tree$where))
    frame <- tree$frame
    n <- row.names(frame)
    node <- as.numeric(n)

    if (missing(nodes)) {
        xy <- rpart:::rpartco(tree)
        i <- identify(xy, n = 1L, plot = FALSE)
        if(i> 0L) {
             return( df[tree$where==i, ] )
        } else {
            return(df[0,])
        }
    }
    else {
        if (length(nodes <- rpart:::node.match(nodes, node)) == 0L) 
            return(df[0,])
        return ( df[tree$where %in% as.numeric(nodes), ] )
    }
}

我将在包

中的一些样本数据上使用它
fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)
plot(fit)
text(fit)

rpart tree plot

然后在特定节点上找到观察结果,运行

subset.rpart(fit, kyphosis)

并单击绘图上的节点。完成后,将返回该节点上的所有观察结果。您必须使用用于建模的相同data.frame才能正常工作。您也可以使用path.rpart

传递您发现的节点名称,而不是单击某个点
# path.rpart(fit)  
#  node number: 10  ---> looks interesting
#    root
#    Start>=8.5
#    Start< 14.5
#    Age< 55

subset.rpart(fit, kyphosis, 10)
#    Kyphosis Age Number Start
# 14   absent   1      4    12
# 20   absent  27      4     9
# 26   absent   9      5    13
# 37   absent   1      3     9
# 39   absent  20      6     9
# 42   absent  35      3    13
# 57   absent   2      3    13
# 59   absent  51      7     9
# 66   absent  17      4    10
# 69   absent  18      4    11
# 78   absent  26      7    13
# 81   absent  36      4    13

答案 1 :(得分:0)

#' subset of rpart node: return logical index
#' @param tree rpart model
#' @param node which node/leaf?
#' @export
subset_rpart <- function (tree, node) {
  nodes = as.numeric(rownames(tree$frame))
  nodes = log(nodes, 2)
  lower = log(node, 2)
  upper = log(node + 1, 2)
  a = floor(lower)
  lower_ = lower - a
  upper_  = upper - a
  nodes_ = nodes %% 1
  w = which(((nodes_ >= lower_ & nodes_ < upper_) | (nodes_ + 1 < upper_)) & nodes >= lower)
  tree$where %in% w
}



#' subset df by subset_rpart
#' @param tree rpart model
#' @param node node number
#' @param df df
#' @export
subset.rpart = function(tree, node, df){
  df[subset_rpart(tree, node), ]
}