使用列表列中的lm来使用purrr预测新值

时间:2017-06-22 21:46:52

标签: r tidyverse purrr broom

我正在尝试将一列预测添加到具有包含lm模型的列表列的数据框中。我采用了this post中的一些代码。

我在这里做了一个玩具示例:

library(dplyr)
library(purrr)
library(tidyr)
library(broom)

set.seed(1234)

exampleTable <- data.frame(
  ind = c(rep(1:5, 5)),
  dep = rnorm(25),
  groups = rep(LETTERS[1:5], each = 5)
) %>%
group_by(groups) %>%
nest(.key=the_data) %>%
mutate(model = the_data %>% map(~lm(dep ~ ind, data = .))) %>%
mutate(Pred = map2(model, the_data, predict))

exampleTable <- exampleTable %>%
  mutate(ind=row_number())

这给了我一个看起来像这样的东西:

# A tibble: 5 × 6
  groups         the_data    model      Pred   ind 
  <fctr>           <list>   <list>    <list> <int> 
1      A <tibble [5 × 2]> <S3: lm> <dbl [5]>     1 
2      B <tibble [5 × 2]> <S3: lm> <dbl [5]>     2 
3      C <tibble [5 × 2]> <S3: lm> <dbl [5]>     3 
4      D <tibble [5 × 2]> <S3: lm> <dbl [5]>     4 
5      E <tibble [5 × 2]> <S3: lm> <dbl [5]>     5 

使用lm模型获取特定组的预测值我可以使用:

predict(exampleTable[1,]$model[[1]], slice(exampleTable, 1) %>% select(ind))

产生这个结果:

> predict(exampleTable[1,]$model[[1]], slice(exampleTable, 1) %>% select(ind))
         1 
-0.4822045

我想为每个小组预测一个新的预测。我尝试使用purrr来获得我想要的东西:

exampleTable %>%
  mutate(Prediction = map2(model, ind, predict))

但是会出现以下错误:

Error in mutate_impl(.data, dots) : object 'ind' not found

我能够通过以下怪物获得我想要的结果:

exampleTable$Prediction <- NA

for(loop in seq_along(exampleTable$groups)){
  lmod <- exampleTable[loop, ]$model[[1]]
  obs <- filter(exampleTable, row_number()==loop) %>%
    select(ind)
  exampleTable[loop, ] $Prediction <- as.numeric(predict(lmod, obs))
}

这给了我一个看起来像这样的东西:

# A tibble: 5 × 6
  groups         the_data    model      Pred   ind Prediction
  <fctr>           <list>   <list>    <list> <int>      <dbl>
1      A <tibble [5 × 2]> <S3: lm> <dbl [5]>     1 -0.4822045
2      B <tibble [5 × 2]> <S3: lm> <dbl [5]>     2 -0.1357712
3      C <tibble [5 × 2]> <S3: lm> <dbl [5]>     3 -0.2455760
4      D <tibble [5 × 2]> <S3: lm> <dbl [5]>     4  0.4818425
5      E <tibble [5 × 2]> <S3: lm> <dbl [5]>     5 -0.3473236

必须有办法在“整洁”中做到这一点。方式,但我不能破解它。

1 个答案:

答案 0 :(得分:2)

您可以利用newdata的{​​{1}}参数。

我使用predict因此它只返回单个值而不是列表。

map2_dbl

如果您在预测之前将mutate(Pred = map2_dbl(model, 1:5, ~predict(.x, newdata = data.frame(ind = .y)))) # A tibble: 5 x 4 groups the_data model Pred <fctr> <list> <list> <dbl> 1 A <tibble [5 x 2]> <S3: lm> -0.4822045 2 B <tibble [5 x 2]> <S3: lm> -0.1357712 3 C <tibble [5 x 2]> <S3: lm> -0.2455760 4 D <tibble [5 x 2]> <S3: lm> 0.4818425 5 E <tibble [5 x 2]> <S3: lm> -0.3473236 添加到数据集,则可以使用该列而不是ind

1:5