从R中的几个嵌套模型的列表列中获得预测

时间:2019-06-27 07:27:35

标签: r function purrr

我参照《R for Data Science》(R 数据科学)一书中“许多模型”(Many Models)一章的示例,并在社区成员(@Ronak)的帮助下得到了 model_list。

但是,我想通过使用for循环或map函数从模型中获得预测,并避免代码中的重复。

我现有的代码基本上是逐个写出每个模型列的名称,并在验证集上进行预测(见下文 “Existing Code”)。我尝试把它改写成 for 循环(见下文 “Attempted Code”),但没有成功:

现有数据集如下所示:

> model_fit
# A tibble: 10 x 16
   Location  data     data_2018   model_lab model_temp model_dep 
   <fct>     <list>   <list>      <list>    <list>     <list>             
 1 Location1  <tibble> <tibble [1~ <list [3~ <list [3]> <list [3]     
 2 Location2  <tibble> <tibble [1~ <list [3~ <list [3]> <list [3]      
 3 Location3 <tibble> <tibble [1~ <list [3~ <list [3]> <list [3]      
 4 Location4  <tibble> <tibble [1~ <list [3~ <list [3]> <list [3]     
Attempted Code:
# Build one prediction list-column per model list-column of `model_fit`.
#
# Fixes relative to the first attempt:
#   * `mutate(pred_names[c] = ...)` is not valid R (an argument name cannot be
#     an expression); the new column is now assigned by name with `[[<-`.
#   * `model_list[c]` is a one-column tibble, not the list-column itself, so
#     the models are now extracted with `[[`.
#   * every iteration restarted from `model_fit`, so only the last column
#     would have survived; results are now accumulated in `model_pred`.
#   * the loop variable no longer shadows `base::c`.

# The model list-columns, in the same order as `pred_names`.
model_list <- dplyr::select(model_fit, model_lab:model_volume)
pred_names <- c("labour_pred", "temp_pred", "depre_pred", "supplies_pred",
                "salary_pred", "expense_pred", "train_pred", "travel_pred",
                "outside_pred", "utilities_pred", "repair_pred", "hours_pred",
                "volume_pred")

# Names are paired with model columns by position, so the counts must agree.
stopifnot(length(pred_names) == ncol(model_list))

# Predict with every model stored in a single cell of a model list-column.
# Returns a list with one element per model. The third model is called with
# its own `n.trees` (presumably a gbm fit -- confirm); every other prediction
# is coerced to a plain numeric vector.
# NOTE(review): as in the hand-written version, no `newdata` is passed to
# `predict()`, so `data_2018` is not used -- confirm whether predictions are
# meant to be made on the validation data.
predict_model_set <- function(models) {
  map(seq_along(models), function(i) {
    if (i == 3) {
      predict(models[[i]], n.trees = models[[i]]$n.trees)
    } else {
      as.numeric(predict(models[[i]]))
    }
  })
}

model_pred <- model_fit
for (k in seq_along(pred_names)) {
  # One list element per row (location), each holding that row's predictions.
  model_pred[[pred_names[k]]] <- map(model_list[[k]], predict_model_set)
}

Existing Code:
# Predict with every model stored in one cell of a model list-column.
#
# `validation` receives the matching `data_2018` cell from `map2()` but is
# not used by the prediction calls.
# NOTE(review): no `newdata` is passed to `predict()` -- confirm whether the
# predictions are meant to be made on `validation`.
#
# Returns a list with one element per model: the third model is predicted
# with its own `n.trees`; all other predictions are coerced with
# `as.numeric()`.
predict_model_set <- function(validation, models) {
  map(seq_along(models), function(idx) {
    fit <- models[[idx]]
    if (idx == 3) {
      predict(fit, n.trees = fit$n.trees)
    } else {
      as.numeric(predict(fit))
    }
  })
}

# Add one prediction list-column per model list-column.
# NOTE(review): there is no `pred_repair` column here -- confirm whether a
# `model_repair` column exists and was left out on purpose.
model_pred <- model_fit %>%
  mutate(
    pred_lab       = map2(data_2018, model_lab, predict_model_set),
    pred_temp      = map2(data_2018, model_temp, predict_model_set),
    pred_dep       = map2(data_2018, model_dep, predict_model_set),
    pred_supplies  = map2(data_2018, model_supplies, predict_model_set),
    pred_salary    = map2(data_2018, model_salary, predict_model_set),
    pred_expense   = map2(data_2018, model_expense, predict_model_set),
    pred_train     = map2(data_2018, model_train, predict_model_set),
    pred_travel    = map2(data_2018, model_travel, predict_model_set),
    pred_outside   = map2(data_2018, model_outside, predict_model_set),
    pred_utilities = map2(data_2018, model_utilities, predict_model_set),
    pred_hours     = map2(data_2018, model_hours, predict_model_set),
    pred_volume    = map2(data_2018, model_volume, predict_model_set)
  )

0 个答案:

没有答案