partykit:将终端节点箱形图更改为显示均值和标准差的条形图

时间:2018-12-20 15:41:34

标签: r tree party

我在R中创建了一个回归树。 这是代码:

# Fit an "anova" rpart tree of y on all other columns of foo, with at
# least 20 observations per split and per leaf and a maximum depth of 3.
tree <- rpart(
  formula = y ~ .,
  data = foo,
  method = "anova",
  minsplit = 20,
  minbucket = 20,
  maxdepth = 3
)

# Convert to a party object and plot it with box plots in the terminal nodes.
plot(as.party(tree), terminal_panel = node_boxplot)

我希望终端节点显示的不是表示中位数和IQR的箱形图,而是条形图:条形的高度表示均值,并带有表示标准差的误差线。我已经测试了所有terminal_panel选项,但是没有一个可以这样做。 有什么建议吗?

1 个答案:

答案 0 :(得分:4)

据我所知,目前还没有这样的面板函数,因此我编写了一个,请参阅下面的node_dynamite()。有了它,就可以这样做:

library("rpart")
library("partykit")
# Fit an rpart tree of dist on speed (cars data) and convert it to a party object.
p <- as.party(rpart(formula = dist ~ speed, data = cars))
# Plot the tree, drawing each terminal node with node_dynamite().
plot(x = p, terminal_panel = node_dynamite)

node_dynamite

## Terminal-node panel generator drawing a "dynamite plot": one bar whose
## height is the (weighted) node mean of the response, plus an error bar of
## +/- factor * (weighted) standard deviation.
##
## Arguments:
##   obj      fitted tree object; obj$fitted must provide "(response)"
##            (numeric) and "(fitted)" (node ids), and may provide "(weights)".
##   factor   multiplier applied to the standard deviation for the error bar.
##   col      line colour of the bar outline and the error bar.
##   fill     fill colour of the bar.
##   bg       fill colour of the panel background.
##   width    bar width (npc units of the plotting region); the error-bar
##            caps span width / 4.
##   yscale   y-axis limits; if NULL they are derived from all nodes so that
##            every panel shares the same scale and includes 0.
##   ylines   width (in lines) of the layout column left of the plot.
##   cex      accepted but not used by this function.
##   id       should the default title include the node id?
##   mainlab  NULL (default title), a character string, or a
##            function(id, nobs) returning the title.
##   gp       graphical parameters passed to the panel's top viewport.
##
## Value: a function(node) that draws the panel for one node.
node_dynamite <- function(obj, factor = 1,
                          col = "black",
                          fill = "lightgray",
                          bg = "white",
                          width = 0.5,
                          yscale = NULL,
                          ylines = 3,
                          cex = 0.5,
                          id = TRUE,
                          mainlab = NULL,
                          gp = gpar())
{
    ## observed data/weights and tree fit
    y <- obj$fitted[["(response)"]]
    stopifnot(is.numeric(y))
    g <- obj$fitted[["(fitted)"]]
    w <- obj$fitted[["(weights)"]]
    if (is.null(w)) w <- rep(1, length(y))

    ## (weighted) means and standard deviations by node
    n <- tapply(w, g, sum)
    m <- tapply(y * w, g, sum) / n
    ## look node means up by node-id name; this avoids calling factor(),
    ## which is shadowed by the numeric `factor` argument
    dev <- y - as.vector(m[as.character(g)])
    v <- tapply(dev^2 * w, g, sum) / (n - 1)
    ## the standard deviation is undefined for nodes with total weight <= 1
    v[n <= 1] <- NA_real_
    s <- sqrt(v)

    if (is.null(yscale)) {
        ## nodes without a defined standard deviation contribute their mean only
        s_lim <- s
        s_lim[!is.finite(s_lim)] <- 0
        yscale <- c(min(c(0, (m - factor * s_lim) * 1.1)),
                    max(c(0, (m + factor * s_lim) * 1.1)))
    }

    ### panel function for bar + error bar in nodes
    rval <- function(node) {

        ## extract the summaries of this node
        nid <- id_node(node)
        key <- as.character(nid)
        mid <- m[key]
        sid <- s[key]
        wid <- n[key]

        ## 2 x 3 layout: title row on top; axis margin, plot, right margin
        top_vp <- viewport(layout = grid.layout(nrow = 2, ncol = 3,
                               widths = unit(c(ylines, 1, 1),
                                             c("lines", "null", "lines")),
                               heights = unit(c(1, 1), c("lines", "null"))),
                           width = unit(1, "npc"),
                           height = unit(1, "npc") - unit(2, "lines"),
                           name = paste0("node_dynamite", nid),
                           gp = gp)

        pushViewport(top_vp)
        grid.rect(gp = gpar(fill = bg, col = 0))

        ## main title
        top <- viewport(layout.pos.col = 2, layout.pos.row = 1)
        pushViewport(top)
        if (is.null(mainlab)) {
            mainlab <- if (id) {
                function(id, nobs) sprintf("Node %s (n = %s)", id, nobs)
            } else {
                function(id, nobs) sprintf("n = %s", nobs)
            }
        }
        if (is.function(mainlab)) {
            mainlab <- mainlab(names(obj)[nid], wid)
        }
        grid.text(mainlab)
        popViewport()

        plot <- viewport(layout.pos.col = 2, layout.pos.row = 2,
                         xscale = c(0, 1), yscale = yscale,
                         name = paste0("node_dynamite", nid, "plot"),
                         clip = FALSE)

        pushViewport(plot)

        ## axis and frame are drawn before clipping is switched on
        grid.yaxis()
        grid.rect(gp = gpar(fill = "transparent"))
        grid.clip()

        ## horizontal extent of the error-bar caps
        xl <- 0.5 - width / 8
        xr <- 0.5 + width / 8

        ## bar from 0 up (or down) to the node mean
        grid.rect(unit(0.5, "npc"), unit(0, "native"),
                  width = unit(width, "npc"), height = unit(mid, "native"),
                  just = c("center", "bottom"),
                  gp = gpar(col = col, fill = fill))

        ## error bar with caps, only if the standard deviation is defined
        if (is.finite(sid)) {
            lo <- mid - factor * sid
            hi <- mid + factor * sid
            grid.lines(unit(0.5, "npc"), unit(c(lo, hi), "native"),
                       gp = gpar(col = col))
            grid.lines(unit(c(xl, xr), "npc"), unit(lo, "native"),
                       gp = gpar(col = col))
            grid.lines(unit(c(xl, xr), "npc"), unit(hi, "native"),
                       gp = gpar(col = col))
        }

        ## leave the plot and top viewports without removing them
        upViewport(2)
    }

    rval
}
## NOTE(review): presumably this class is what lets plot() accept the
## generator as terminal_panel -- confirm against the partykit documentation.
class(node_dynamite) <- "grapcon_generator"