使用purrr map2函数

时间:2017-07-28 18:13:31

标签: r lm tidyverse purrr broom

我的问题类似于this one,但现在我正在尝试使用具有多个预测变量的模型,而我无法弄清楚如何将新数据转换为预测函数。

library(dplyr)
library(lubridate)
library(purrr)
library(tidyr)
library(broom)

set.seed(1234)

首先我创建一个星期几

wks = seq(as.Date("2010-01-01"), Sys.Date(), by="1 week")

然后我抓住当年

cur_year <- year(Sys.Date())

这里我创建了一个带有虚拟数据的数据框

my_data <- data.frame(
  week_ending = wks
) %>% 
  mutate(
    ref_period = week(week_ending),
    yr = year(week_ending),
    PCT.EXCELLENT = round(runif(length(wks), 0, 100),0),
    PCT.GOOD = round(runif(length(wks), 0, 100),0),
    PCT.FAIR = round(runif(length(wks), 0, 100),0),
    PCT.POOR = round(runif(length(wks), 0, 100),0),
    PCT.VERY.POOR = round(runif(length(wks), 0, 100),0),
    pct_trend = round(runif(length(wks), 75, 125),0)
  )

接下来,我创建一个嵌套的数据框,其中包含一年中每周的数据作为一个组。

cond_model <- my_data %>% 
  filter(yr != cur_year) %>% 
  group_by(ref_period) %>% 
  nest(.key=cond_data) 

在这里,我将今年的数据加入到今年一周的前几年的数据中。

cond_model <- left_join(
  cond_model,
  my_data %>% 
    filter(yr==cur_year) %>% 
    select(week_ending,
           ref_period,
           PCT.EXCELLENT,
           PCT.FAIR,
           PCT.GOOD,
           PCT.POOR,
           PCT.VERY.POOR),
  by = c("ref_period")
) 

这会将线性模型添加到一年中每周的数据框中

cond_model <- 
  cond_model %>% 
  mutate(model = map(cond_data,
                     ~lm(pct_trend ~ PCT.EXCELLENT + PCT.GOOD + PCT.FAIR + PCT.POOR + PCT.VERY.POOR, data = .x)))

现在我想用每周的模型来预测使用今年的数据。我尝试了以下方法:

cond_model <- 
  cond_model %>% 
  mutate(
    pred_pct_trend = map2_dbl(model, PCT.EXCELLENT + PCT.GOOD + PCT.FAIR + PCT.POOR + PCT.VERY.POOR,
                              ~predict(.x, newdata = data.frame(.y)))
  )

这会产生以下错误:

Error in mutate_impl(.data, dots) : object 'PCT.EXCELLENT' not found

然后我尝试在我的数据框中嵌套我的预测变量......

使用今年的数据创建数据框并嵌套预测变量

cur_cond <- my_data %>% 
  filter(yr==cur_year) %>% 
  select(week_ending, PCT.EXCELLENT,
         PCT.GOOD, PCT.FAIR, PCT.POOR, PCT.VERY.POOR) %>% 
  group_by(week_ending) %>% 
  nest(.key=new_data) %>% 
  mutate(new_data=map(new_data, ~data.frame(.x)))

将其加入我的主数据框

cond_model <- left_join(cond_model, cur_cond)

现在我再次尝试预测:

cond_model <- 
  cond_model %>% 
  mutate(
    pred_pct_trend = map2_dbl(model, new_data,
                              ~predict(.x, newdata = data.frame(.y)))
  )

我得到了和以前一样的错误:

Error in mutate_impl(.data, dots) : object 'PCT.EXCELLENT' not found

我认为答案可能涉及对预测变量执行flatten(),但我无法弄清楚我的工作流程在哪里。

cond_model$new_data[1]

VS

flatten_df(cond_model$new_data[1])

此时我已经没有想法了。

1 个答案:

答案 0 :(得分:2)

添加预测数据集后,主要问题是如何处理没有预测数据的周数(第31-53周)。

您将看到加入两个数据集的时间,没有预测数据集的行将填充NULL。您可以使用ifelse语句为这些行预测NA

# Modeling data
cond_model = my_data %>%
    filter(yr != cur_year) %>%
    group_by(ref_period) %>%
    nest(.key = cond_data)

# Create prediction data
cur_cond = my_data %>%
    filter(yr == cur_year) %>% 
    group_by(ref_period) %>% 
    nest( .key = new_data )

# Join these together
cond_model = left_join(cond_model, cur_cond)

# Models
cond_model = cond_model %>% 
    mutate(model = map(cond_data,
                       ~lm(pct_trend ~ PCT.EXCELLENT + PCT.GOOD + 
                               PCT.FAIR + PCT.POOR + PCT.VERY.POOR, data = .x) ) )

如果没有预测数据,请将ifelse放入NA

# Predictions
cond_model %>% 
    mutate(pred_pct_trend = map2_dbl(model, new_data,
                                     ~ifelse(is.null(.y), NA, 
                                             predict(.x, newdata = .y) ) ) )

# A tibble: 53 x 5
   ref_period        cond_data         new_data    model pred_pct_trend
        <dbl>           <list>           <list>   <list>          <dbl>
 1          1 <tibble [7 x 8]> <tibble [1 x 8]> <S3: lm>       83.08899
 2          2 <tibble [7 x 8]> <tibble [1 x 8]> <S3: lm>      114.39089
 3          3 <tibble [7 x 8]> <tibble [1 x 8]> <S3: lm>      215.02055
 4          4 <tibble [7 x 8]> <tibble [1 x 8]> <S3: lm>      130.24556
 5          5 <tibble [7 x 8]> <tibble [1 x 8]> <S3: lm>      112.86516
 6          6 <tibble [7 x 8]> <tibble [1 x 8]> <S3: lm>      107.29866
 7          7 <tibble [7 x 8]> <tibble [1 x 8]> <S3: lm>       52.11526
 8          8 <tibble [7 x 8]> <tibble [1 x 8]> <S3: lm>      106.22482
 9          9 <tibble [7 x 8]> <tibble [1 x 8]> <S3: lm>      128.40858
10         10 <tibble [7 x 8]> <tibble [1 x 8]> <S3: lm>      108.10306