Generating interactive partial dependence plots in R with a loop is extremely slow

Date: 2019-03-21 19:51:12

Tags: r ggplot2 partial ggplotly

I'm trying to generate interactive partial dependence plots by looping over the columns of a dataset.

Reproducible example:

library(pdp)
library(xgboost)
library(Matrix)
library(ggplot2)
library(plotly)

data(mtcars)
target <- mtcars$mpg
mtcars$mpg <- NULL

mtcars.sparse <- sparse.model.matrix(target~., mtcars)

fit <- xgboost(data=mtcars.sparse, label=target, nrounds=100)

for (i in seq_along(names(mtcars))){
  p1 <- pdp::partial(fit,
                     pred.var = names(mtcars)[i],
                     pred.grid = data.frame(unique(mtcars[names(mtcars)[i]])),
                     train = mtcars.sparse,
                     type = "regression",
                     cats = c("cyl", "vs", "am", "gear", "carb"),
                     plot = FALSE)
  p2 <- ggplot(aes_string(x = names(mtcars)[i] , y = "yhat"), data = p1) +
    geom_line(color = '#E51837', size = .6) +
    labs(title = paste("Partial Dependence plot of", names(mtcars)[i] , sep = " ")) +
    theme(text = element_text(color = "#444444", family = 'Helvetica Neue'),
          plot.title = element_text(size = 13, color = '#333333'))

  print(ggplotly(p2, tooltip = c("x", "y")))

}

On my real dataset (about 22k rows, 30 columns) this plotting loop takes about 2 hours. Any ideas on how to speed it up?

1 answer:

Answer 0 (score: 1)

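Because of the way R handles data structures, for() loops can be very slow if you are not careful, for example when an object is grown inside the loop and has to be copied on every iteration. If you want to dig into the technical reasons, see Hadley Wickham's Advanced R.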
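A minimal base-R sketch of the "grow inside the loop" problem. The helper names `grow()` and `prealloc()` are hypothetical and only used for illustration: the first extends its result with `c()` on every iteration (a new allocation plus a copy each time), the second allocates the result once and fills it by index.

```r
# Slow pattern: the result vector grows by one element per iteration,
# so every c() call allocates a new vector and copies the old contents
grow <- function(n) {
  out <- c()
  for (i in seq_len(n)) {
    out <- c(out, i^2)
  }
  out
}

# Faster pattern: allocate the full result once, then fill it in place
prealloc <- function(n) {
  out <- numeric(n)
  for (i in seq_len(n)) {
    out[i] <- i^2
  }
  out
}

# Same values, very different amount of copying
system.time(a <- grow(2e4))
system.time(b <- prealloc(2e4))
identical(a, b)
```

The copying cost of `grow()` rises roughly quadratically with `n`, so the gap widens quickly as the loop gets longer.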

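In practice there are two main ways to speed up your code: optimize the for() loop, or use the apply() family of functions. Both work well. A carefully written for() loop (for example, one that preallocates its result) is usually about as fast as lapply(), but the apply() version puts the per-variable work into a single function, which makes it easy to parallelize later, so I'll stick with that solution.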
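Here are the two approaches side by side on a toy task. `make_summary()` is a hypothetical stand-in for the per-column work (`pdp::partial()` plus `ggplot()` in the question); the loop version preallocates its result list, and both versions end up producing the same list.

```r
# Hypothetical stand-in for the per-column work done in the question
make_summary <- function(col) {
  x <- datasets::mtcars[[col]]
  data.frame(var = col, mean = mean(x), n_unique = length(unique(x)))
}

vars <- names(datasets::mtcars)

# Optimized for() loop: preallocate a list of the right length, fill by index
res_loop <- vector("list", length(vars))
for (i in seq_along(vars)) {
  res_loop[[i]] <- make_summary(vars[i])
}

# apply-family version: lapply() builds and returns the list for you
res_apply <- lapply(vars, make_summary)

identical(res_loop, res_apply)
```

The practical advantage of the lapply() form is that the work is already isolated in a function of one argument, which is exactly the shape parallel map functions expect.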

The apply approach:

# Predictor names to loop over (varNames was not defined in this snippet before)
varNames <- names(mtcars)

# Build the partial dependence plot for one predictor and return the ggplot object
plotFunction <- 
  function(x) {
    p1 <- pdp::partial(fit,
                       pred.var = x,
                       pred.grid = data.frame(unique(mtcars[x])),
                       train = mtcars.sparse,
                       type = "regression",
                       cats = c("cyl", "vs", "am", "gear", "carb"),
                       plot = FALSE)
    p2 <- ggplot(aes_string(x = x , y = "yhat"), data = p1) +
      geom_line(color = '#E51837', size = .6) +
      labs(title = paste("Partial Dependence plot of", x , sep = " ")) +
      theme(text = element_text(color = "#444444", family = 'Helvetica Neue'),
            plot.title = element_text(size = 13, color = '#333333'))
    return(p2)
  }

# Time the run and keep the list of plots (no need to call lapply() twice)
system.time(plot.list <- lapply(varNames, plotFunction))
   user  system elapsed 
  0.471   0.004   0.488 

Running the same benchmark on your for() loop gives:

   user  system elapsed 
  3.945   0.616   3.519 

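Simply moving the loop body into a function and calling it with lapply() made this run roughly 7 times faster (3.519 s vs 0.488 s elapsed). Note, though, that plotFunction() only builds the ggplot objects, whereas your loop also calls ggplotly() and print() on every plot, so part of that difference comes from skipping the plotly conversion and rendering, not from lapply() itself.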

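If you need more speed you could tweak the function further, but the most useful property of the apply() approach is probably that it is easy to parallelize, for example with a package like pbmcapply.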
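pbmcapply builds on `parallel::mclapply()`, so the same idea can be sketched with the base `parallel` package alone. `slow_square()` is a hypothetical placeholder for one expensive per-variable task. `mclapply()` forks worker processes, which Windows does not support, so this sketch assumes one core there (with `mc.cores = 1`, `mclapply()` simply runs serially).

```r
library(parallel)

# Hypothetical placeholder for one expensive per-variable task
slow_square <- function(i) {
  Sys.sleep(0.2)  # simulate work such as one pdp::partial() call
  i^2
}

# Forking is unavailable on Windows, so fall back to a single core there;
# elsewhere leave one core free (na.rm guards against detectCores() = NA)
n_cores <- if (.Platform$OS.type == "windows") {
  1L
} else {
  max(1L, detectCores() - 1L, na.rm = TRUE)
}

t_serial   <- system.time(r1 <- lapply(1:8, slow_square))
t_parallel <- system.time(r2 <- mclapply(1:8, slow_square, mc.cores = n_cores))

t_serial[["elapsed"]]
t_parallel[["elapsed"]]
identical(r1, r2)
```

Because `mclapply()` has the same `X, FUN` shape as `lapply()`, switching between the serial and parallel versions only changes the call, not the task function.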

Using pbmcapply gives you a further speedup:

library(pdp)
library(xgboost)
library(Matrix)
library(ggplot2)
library(plotly)
library(pbmcapply)

# Number of cores to use for parallel processing.
# This leaves one core free for the rest of the system; detectCores()
# comes from the parallel package, so it is called with its namespace
nCores <- parallel::detectCores() - 1

data(mtcars)
target <- mtcars$mpg
mtcars$mpg <- NULL

mtcars.sparse <- sparse.model.matrix(target~., mtcars)

fit <- xgboost(data=mtcars.sparse, label=target, nrounds=100)

# Predictor names to loop over (%>% was used here without loading magrittr)
varNames <- as.list(names(mtcars))

plotFunction <- 
  function(x) {
    p1 <- pdp::partial(fit,
                       pred.var = x,
                       pred.grid = data.frame(unique(mtcars[x])),
                       train = mtcars.sparse,
                       type = "regression",
                       cats = c("cyl", "vs", "am", "gear", "carb"),
                       plot = FALSE)
    p2 <- ggplot(aes_string(x = x , y = "yhat"), data = p1) +
      geom_line(color = '#E51837', size = .6) +
      labs(title = paste("Partial Dependence plot of", x , sep = " ")) +
      theme(text = element_text(color = "#444444", family = 'Helvetica Neue'),
            plot.title = element_text(size = 13, color = '#333333'))
    return(p2)
  }


plot.list <- pbmclapply(varNames, plotFunction, mc.cores = nCores)

Here is how it performs:

   user  system elapsed 
  0.842   0.458   0.320 

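That is only a modest improvement over lapply() here (0.320 s vs 0.488 s elapsed), but the gain should grow with larger datasets, where each partial() call takes much longer relative to the overhead of starting the worker processes. Hope this helps!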
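Why should the gain grow with the data? By definition, a brute-force partial dependence curve is computed by taking each grid value, overwriting the chosen column with it in a copy of the training data, predicting on every row, and averaging. The work is therefore roughly (grid points) × (training rows) predictions per variable, which would explain why ~22k rows with a grid of every unique value is slow. The sketch below assumes that brute-force definition; it is not pdp's own code. It uses `lm()` as a stand-in for the xgboost model and a hypothetical `pd_manual()` helper. It also shows one possible tweak: averaging over a random subsample of rows (in the question's code, that would mean passing fewer rows as `train` or a shorter `pred.grid`) cuts the work proportionally.

```r
# Brute-force partial dependence: for every grid value, overwrite the
# column in a copy of the training data, predict on all rows, and average.
# Cost: length(grid) * nrow(train) predictions.
pd_manual <- function(model, train, var, grid) {
  yhat <- vapply(grid, function(v) {
    tmp <- train
    tmp[[var]] <- v
    mean(predict(model, newdata = tmp))
  }, numeric(1))
  data.frame(x = grid, yhat = yhat)
}

# lm() stands in for the xgboost model from the question
cars <- datasets::mtcars
fit_lm <- lm(mpg ~ ., data = cars)

# Grid of every unique value, as in the question's pred.grid
grid <- sort(unique(cars$wt))
pd_full <- pd_manual(fit_lm, cars, "wt", grid)

# Cheaper variant: average over a random half of the training rows
set.seed(1)
idx <- sample(nrow(cars), 16)
pd_sub <- pd_manual(fit_lm, cars[idx, ], "wt", grid)
```

For a linear model the curve is a straight line whose slope equals the `wt` coefficient, which makes the sketch easy to sanity-check; with xgboost the curve would be non-linear, but the cost structure is the same.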