将数据集拆分为列表并使用lm模型R

时间:2017-12-20 17:47:54

标签: r split lapply lm

我尝试使用caret包在我的数据集中应用lm模型。

可重复的例子:

df <- data.frame(x = 1:10000, y = sample(1:1000, 10000, replace = TRUE), group = sample(c('A', 'B', 'C'), 10000, replace = TRUE, prob = c(.1, .5, .4)))

df_list <- split(df, df$group)

df_list <- lapply(df_list, function(x) select(x, -group))

创建数据分区会引发错误。我想使用caret&#39; createDataPartition对数据进行分区,然后应用train函数。

train_test <- lapply(df_list, function(x) createDataPartition(x, p = .8, list = FALSE))

model_list <- lapply(train_test, function(z) train(x ~ ., z, method = 'lm', trControl = trainControl(method = 'cv', number = 10, verboseIter = TRUE), preProcess = c('nzv', 'center', 'scale'))

我认为这是解决列表结构的一个简单问题,但由于某种原因,我遇到了问题。感谢帮助!

4 个答案:

答案 0 :(得分:1)

createDataPartition接受一个向量,而不是数据帧:

train_test <- lapply(df_list, function(x) createDataPartition(x$y, p = .8, list = FALSE))

答案 1 :(得分:1)

我认为分区错误是由于createDataPartition需要向量而不是数据帧。我想你可以做到:

train_test <- lapply(df_list, function(x) {
  x[createDataPartition(x$x, p = 0.8, list = FALSE),]
})

然后你的model_list <- ...块为我工作。

据我所知,这不应该搞砸你的索引:

set.seed(123)
df_small <- data.frame(x = runif(10), y = letters[1:10])
df_small_part <- df_small[createDataPartition(df_small$x, list = FALSE),]

> join(df_small, df_small_part, type = "left", by = "y")
           x y         x
1  0.2875775 a 0.2875775
2  0.7883051 b        NA
3  0.4089769 c        NA
4  0.8830174 d 0.8830174
5  0.9404673 e 0.9404673
6  0.0455565 f 0.0455565
7  0.5281055 g        NA
8  0.8924190 h        NA
9  0.5514350 i 0.5514350
10 0.4566147 j 0.4566147

答案 2 :(得分:1)

如果在控制台中键入?createDataPartition,则可以看到该功能的正确用法。

也就是说,它的通用格式如下:

createDataPartition(y, times = 1, p = 0.5, list = TRUE, groups = min(5,
  length(y)))

其中y是&#34;结果的矢量&#34;。它需要特定结果的原因是为了使结果变量(我假设在你的情况下为y)平衡训练和测试分裂。

因此,而不是您拥有的以下代码:

train_test <- lapply(df_list, function(x) createDataPartition(x, p = .8, list = FALSE))

将其替换为以下内容:

train_test <- lapply(df_list, function(x) { 
  return(createDataPartition(x$y, p = .8, list = FALSE))
  })

要明确的是,唯一的修改是添加了$y

然而,这会导致你的最后一行(你在lapply()train()函数的行)的另一个错误。你看,createDataPartition()返回用于你的数据帧的INDEXES。换句话说,要在df_list中获取每个df的训练集,您必须使用例如(df_list[[1]])[train_test[[1]],]。随后,要获得相应的测试集,您必须使用例如(df_list[[1]])[-train_test[[1]],](注意减号符号)。因此,您应该将最后一行重写为以下内容:

model_list <- purrr::map2(df_list, train_test, 
                          function(df, train_index)  {
                            train(x ~ ., df[train_index,], 
                                  method = 'lm', 
                                  trControl = trainControl(method = 'cv', 
                                                           number = 10, 
                                                           verboseIter = TRUE), 
                                  preProcess = c('nzv', 'center', 'scale')) 
                            })

请注意,purrr的map2函数类似于sapply / lapply(其中sapply / lapply为列表中的每个元素调用一个函数)。唯一的区别是map2迭代 2 列表(df_list和train_test)。

我希望这有帮助!

编辑:如果您想了解更多关于插入符号包的信息,建议您使用以下链接:http://topepo.github.io/caret/data-splitting.html

答案 3 :(得分:1)

这是purrr列表列tidyverse - 符合Jenny Bryan的解决方案。请提供您的意见,如何让它更清洁。

library(dplyr)
library(tidyr)
library(purrr)

df <- data.frame(x = 1:10000, y = sample(1:1000, 10000, replace = TRUE), 
                 group = sample(c('A', 'B', 'C'), 10000, replace = TRUE, prob = c(.1, .5, .4)))

df %>% group_by(group) %>% nest() %>% 
  mutate(dataPart = map(data, ~caret::createDataPartition(.x$x, p = .8, list = FALSE) )) %>% 
  mutate(model_list = map2(data, dataPart, ~caret::train(x ~ ., 
                                      data=.x[.y,], 
                                      method = 'lm', 
                                      trControl = caret::trainControl(method = 'cv', number = 10, verboseIter = TRUE), 
                                      preProcess = c('nzv', 'center', 'scale'))),
         oof_prediction=pmap(list(data, dataPart, model_list), ~caret::predict.train(..3, newdata=..1[-..2, ])),
         oof_error=pmap(list(data, dataPart, oof_prediction), ~caret::postResample(..3, ..1$x[-..2])),
         oof_error=map(oof_error, ~as.data.frame(t(.x)))) %>% 
  unnest(oof_error)
  

data.frame中会发生什么,保留在data.frame中 - Hadley Wickham

# A tibble: 3 x 7
   group                 data          dataPart  model_list oof_prediction     RMSE     Rsquared
  <fctr>               <list>            <list>      <list>         <list>    <dbl>        <dbl>
1      C <tibble [3,971 x 2]> <int [3,179 x 1]> <S3: train>    <dbl [792]> 2902.691 2.386907e-05
2      B <tibble [5,041 x 2]> <int [4,033 x 1]> <S3: train>  <dbl [1,008]> 2832.764 3.075320e-04
3      A   <tibble [988 x 2]>   <int [792 x 1]> <S3: train>    <dbl [196]> 2861.664 3.438135e-03