监视进度/从purrr :: invoke_map()打印到控制台

时间:2017-10-01 21:07:40

标签: r dplyr r-caret tidyverse purrr

我正在尝试使用caret在R中的列表列格式中训练多个purrr::invoke_map()模型(请参阅:this blog post)。

致电invoke_map()时,我希望能够以某种方式监控进度。具体来说,我想打印行号,或id列,invoke_map()列遍历模型/数据组合。有没有办法做到这一点,可能是通过修改训练功能(linearRegModel()下面的?)

library(tidyverse)
library(mlbench)
library(caret)
data("BostonHousing") # from mlbench


starter_df <- 
  list(BostonHousing) %>% 
  rep(3) %>% 
  enframe(name = 'id', value = 'rawdata')  %>% 
  transmute(
    id
    , train.X = map(rawdata,  ~ .x %>% select(-medv))
    , train.Y = map(rawdata, ~ .x$medv)
  )



# re-write any caret training method as a function. 
# using linear regression here for simplicity
linearRegModel <- function(X, Y) {
 ctrl <- trainControl(
    method = "repeatedcv", 
    number = 2
  )
  train(
    x = X,
    y = Y,
    method = 'lm',
    trControl = ctrl,
    preProc = c('center', 'scale')
  )
}


# convert models to tibble
model_list <- 
  list(linearRegModel = linearRegModel,
       linearRegModel2 = linearRegModel) %>%
  enframe(name = 'modelName',value = 'model')

# combine model tibble with the data tibble
train_df <- 
  starter_df[rep(1:nrow(starter_df),nrow(model_list)),] %>% 
  bind_cols(
    model_list[rep(1:nrow(model_list),nrow(starter_df)),] %>% arrange(modelName)
  ) %>%
  mutate(id=1:nrow(.))

train_df



# A tibble: 6 x 5
     id                 train.X     train.Y       modelName  model
  <int>                  <list>      <list>           <chr> <list>
1     1 <data.frame [506 x 13]> <dbl [506]>  linearRegModel  <fun>
2     2 <data.frame [506 x 13]> <dbl [506]>  linearRegModel  <fun>
3     3 <data.frame [506 x 13]> <dbl [506]>  linearRegModel  <fun>
4     4 <data.frame [506 x 13]> <dbl [506]> linearRegModel2  <fun>
5     5 <data.frame [506 x 13]> <dbl [506]> linearRegModel2  <fun>
6     6 <data.frame [506 x 13]> <dbl [506]> linearRegModel2  <fun>


# train models by calling invoke_map()
# (takes a few seconds)
data_with_model_fits <-
  train_df %>%
  mutate(params = map2(train.X, train.Y,  ~ list(X = .x, Y = .y)),
         modelFits = invoke_map(model,params)
  )

1 个答案:

答案 0 :(得分:1)

您可能会发现progress包很有意思。下面我把它整合到你的问题中。请注意两件事:

  • 在开始使用progress::progress_bar(tick = number_of_ticks)拟合模型之前初始化进度条。

  • linRegModel()功能中,您可以在模型适合pb$tick()后“勾选”进度条。

pb是使用面向对象技术的R6对象,因此您不必将其作为参数传递给linRegModel()函数。

希望它有所帮助。

library(tidyverse)
library(mlbench)
library(caret)

data("BostonHousing") # from mlbench

library(progress)

starter_df <- 
    list(BostonHousing) %>% 
    rep(3) %>% 
    enframe(name = 'id', value = 'rawdata')  %>% 
    transmute(
        id
        , train.X = map(rawdata,  ~ .x %>% select(-medv))
        , train.Y = map(rawdata, ~ .x$medv)
    )



# re-write any caret training method as a function. 
# using linear regression here for simplicity
linearRegModel <- function(X, Y) {
    ctrl <- trainControl(
        method = "repeatedcv", 
        number = 2
    )
    train(
        x = X,
        y = Y,
        method = 'lm',
        trControl = ctrl,
        preProc = c('center', 'scale')
    )

    # Tick the progress bar forward 1 tick after each completed model fit
    pb$tick()
}


# convert models to tibble
model_list <- 
    list(linearRegModel = linearRegModel,
         linearRegModel2 = linearRegModel) %>%
    enframe(name = 'modelName',value = 'model')

# combine model tibble with the data tibble
train_df <- 
    starter_df[rep(1:nrow(starter_df),nrow(model_list)),] %>% 
    bind_cols(
        model_list[rep(1:nrow(model_list),nrow(starter_df)),] %>% arrange(modelName)
    ) %>%
    mutate(id=1:nrow(.))

train_df


# initialize progress bar
ticks <- nrow(train_df)
pb <- progress::progress_bar$new(total = ticks)

# train models by calling invoke_map()
# (takes a few seconds)
data_with_model_fits <-
train_df %>%
mutate(params = map2(train.X, train.Y,  ~ list(X = .x, Y = .y)),
       modelFits = invoke_map(model,params)
)

为了增加灵活性,您可以在创建进度条时通过token参数使用format。内置了一些,如:current,以显示当前的迭代。这可能更直接地回答您的问题。在这种情况下,我会在模型运行之前调用pb$tick()。该文档还建议在长时间运行计算之前运行pb$tick(0)以立即显示进度条。

# initialize progress bar
pb <- progress::progress_bar$new(format = "running model :current", show_after = .01)
pb$tick(0)