For-Loop By Columns,具有现有的For-loop by Rows

时间:2018-05-08 15:20:49

标签: r function for-loop dplyr apply

我有一个数据集如下作为样本。我的实际数据集有5000列:

# Define Adstock Rate
adstock_rate = 0.50
lag_number = 3
# Create Data
advertising = c(117.913, 120.112, 125.828, 115.354, 177.090, 141.647, 137.892,   0.000,   0.000,   0.000,   0.000, 
        0.000,   0.000,   0.000,   0.000,   0.000,   0.000, 158.511, 109.385,  91.084,  79.253, 102.706, 
        78.494, 135.114, 114.549,  87.337, 107.829, 125.020,  82.956,  60.813,  83.149,   0.000,   0.000, 
        0.000,   0.000,   0.000,   0.000, 129.515, 105.486, 111.494, 107.099,   0.000,   0.000,   0.000, 
        0.000,   0.000,   0.000,   0.000,   0.000,   0.000,   0.000,   0.000,
        134.913, 123.112, 178.828, 112.354, 100.090, 167.647, 177.892,   0.000,   0.000,   0.000,   0.000, 
        0.000,   0.000,   0.000,   0.000,   0.000,   0.000, 112.511, 155.385,  123.084,  89.253, 67.706, 
        23.494, 122.114, 112.549,  65.337, 134.829, 123.020,  81.956,  23.813,  65.149,   0.000,   0.000, 
        0.000,   0.000,   0.000,   0.000, 145.515, 154.486, 121.494, 117.099,   0.000,   0.000,   0.000, 
        0.000,   0.000,   0.000,   0.000,   0.000,   0.000,   0.000,   0.000
        )

advertising2 = c(43.913, 231.112, 76.828, 22.354, 98.090, 123.647, 90.892,   0.000,   0.000,   0.000,   0.000, 
        0.000,   0.000,   0.000,   0.000,   0.000,   0.000, 234.511, 143.385,  78.084,  89.253, 12.706, 
        34.494, 56.114, 78.549,  12.337, 67.829, 42.020,  90.956,  23.813,  83.149,   0.000,   0.000, 
        0.000,   0.000,   0.000,   0.000, 52.515, 76.486, 89.494, 12.099,   0.000,   0.000,   0.000, 
        0.000,   0.000,   0.000,   0.000,   0.000,   0.000,   0.000,   0.000,
        67.913, 12.112, 45.828, 78.354, 89.090, 90.647, 23.892,   0.000,   0.000,   0.000,   0.000, 
        0.000,   0.000,   0.000,   0.000,   0.000,   0.000, 78.511, 23.385,  43.084,  67.253, 33.706, 
        56.494, 78.114, 98.549,  45.337, 31.829, 67.020,  87.956,  94.813,  65.149,   0.000,   0.000, 
        0.000,   0.000,   0.000,   0.000, 55.515, 32.486, 78.494, 33.099,   0.000,   0.000,   0.000, 
        0.000,   0.000,   0.000,   0.000,   0.000,   0.000,   0.000,   0.000
        )
Region = c(500, 500, 500, 500, 500, 500, 500, 500,500, 500, 500, 500,500, 500, 500, 500,500, 500, 500, 500,500, 500, 500, 500,
   500, 500, 500, 500,500, 500, 500, 500,500, 500, 500, 500,500, 500, 500, 500,500, 500, 500, 500,500, 500, 500, 500, 500, 500, 
   500, 500,
   501, 501, 501, 501, 501, 501, 501, 501,501, 501, 501, 501,501, 501, 501, 501,501, 501, 501, 501,501, 501, 501, 501,
   501, 501, 501, 501,501, 501, 501, 501,501, 501, 501, 501,501, 501, 501, 501,501, 501, 501, 501,501, 501, 501, 501, 501, 501, 
   501, 501)

advertising_dataset<-data.frame(cbind(Region, advertising, advertising2))

我的数据集如下所示:

head(advertising_dataset, 15)

   Region advertising advertising2
1     500     117.913       43.913
2     500     120.112      231.112
3     500     125.828       76.828
4     500     115.354       22.354
5     500     177.090       98.090
6     500     141.647      123.647
7     500     137.892       90.892
8     500       0.000        0.000
9     500       0.000        0.000
10    500       0.000        0.000
11    500       0.000        0.000
12    500       0.000        0.000
13    500       0.000        0.000
14    500       0.000        0.000
15    500       0.000        0.000

然后,只为 1 列创建for循环,然后在Region之后创建group_by函数。

foo <- function(df_, lag_val = 1) {
  df_$adstocked_advertising = df_$advertising
  for (i in (1 + lag_val):nrow(df_)) {
    df_$adstocked_advertising[i] = df_$advertising[i] + adstock_rate * 
df_$adstocked_advertising[i - lag_val]
  }
  return(df_)
}


adv_2 <- data.frame(advertising_dataset %>%
                      group_by(Region) %>%
                      do(foo(data.frame(.), lag_val = 3)))    

如何将上述功能(包括adv_2)应用于2:ncol(advertising_dataset)的所有列,而不仅仅是advertising列?

我的最终列数最终会加倍,因为将为每个现有列创建一个新修订的列。

我有一种感觉,就像上面的函数一样:

data.frame(advertising_dataset[1], 
apply(advertising_dataset[2:ncol(advertising_dataset)],2, foo) )

任何帮助都会很棒,谢谢!

2 个答案:

答案 0 :(得分:2)

我们可以将accumulatemutate_all

一起使用
library(tidyverse)
out <- advertising_dataset %>% 
         group_by(Region) %>%
         mutate_all(funs(adstocked = accumulate(., ~ .y + adstock_rate * .x)))
out
# A tibble: 104 x 5
# Groups:   Region [2]
#   Region advertising advertising2 advertising_adstocked advertising2_adstocked
#    <dbl>       <dbl>        <dbl>                 <dbl>                  <dbl>
# 1    500        118.         43.9                 118.                    43.9
# 2    500        120.        231.                  179.                   253. 
# 3    500        126.         76.8                 215.                   203. 
# 4    500        115.         22.4                 223.                   124. 
# 5    500        177.         98.1                 289.                   160. 
# 6    500        142.        124.                  286.                   204. 
# 7    500        138.         90.9                 281.                   193. 
# 8    500          0           0                   140.                    96.4
# 9    500          0           0                    70.2                   48.2
#10    500          0           0                    35.1                   24.1
# ... with 94 more rows

检查OP解决方案的输出

head(out[[4]])
#[1] 117.9130 179.0685 215.3623 223.0351 288.6076 285.9508

head(adv_2[[4]])
#[1] 117.9130 179.0685 215.3623 223.0351 288.6076 285.9508

更新

我们可以针对不同的foo

修改OP的函数lag_val
foo1 <- function(dot, lag_val = 1) {
     tmp <- dot
     for(i in (1 + lag_val): length(tmp)) {
           tmp[i] <- tmp[i] + adstock_rate * tmp[i - lag_val]
     }
     return(tmp)
   }


advertising_dataset %>%
       group_by(Region) %>%
       mutate_all(funs(adstocked = foo1(., lag_val = 1)))
# A tibble: 104 x 5
# Groups:   Region [2]
#   Region advertising advertising2 advertising_adstocked advertising2_adstocked
#    <dbl>       <dbl>        <dbl>                 <dbl>                  <dbl>
# 1    500        118.         43.9                 118.                    43.9
# 2    500        120.        231.                  179.                   253. 
# 3    500        126.         76.8                 215.                   203. 
# 4    500        115.         22.4                 223.                   124. 
# 5    500        177.         98.1                 289.                   160. 
# 6    500        142.        124.                  286.                   204. 
# 7    500        138.         90.9                 281.                   193. 
# 8    500          0           0                   140.                    96.4
# 9    500          0           0                    70.2                   48.2
#10    500          0           0                    35.1                   24.1
# ... with 94 more rows

- 更改lag_val

advertising_dataset %>%
            group_by(Region) %>%
            mutate_all(funs(adstocked = foo1(., lag_val = 2)))
# A tibble: 104 x 5
# Groups:   Region [2]
#   Region advertising advertising2 advertising_adstocked advertising2_adstocked
#    <dbl>       <dbl>        <dbl>                 <dbl>                  <dbl>
# 1    500        118.         43.9                 118.                    43.9
# 2    500        120.        231.                  120.                   231. 
# 3    500        126.         76.8                 185.                    98.8
# 4    500        115.         22.4                 175.                   138. 
# 5    500        177.         98.1                 269.                   147. 
# 6    500        142.        124.                  229.                   193. 
# 7    500        138.         90.9                 273.                   165. 
# 8    500          0           0                   115.                    96.3
# 9    500          0           0                   136.                    82.3
#10    500          0           0                    57.3                   48.2
# ... with 94 more rows

答案 1 :(得分:0)

使用data.table,您可以使用j语句中的lapply将任何函数应用于列的子集。但是,这些函数将作为向量作用于列,因此您可能希望将F1编辑为任何向量的通用。但是,基本上,你所要做的就是:

foo()