Sum of two columns with a lagged term

Date: 2018-07-27 03:28:27

Tags: r dplyr data.table

Suppose I have a data.frame like this:

dt <- data.frame(id = rep(letters[1:3], each = 4),
                 year = rep(1:4, 3),
                 invest = 1:12,
                 y = rep(c(1, 0, 0, 0), 3))
dt
   id year invest y
1   a    1      1 1
2   a    2      2 0
3   a    3      3 0
4   a    4      4 0
5   b    1      5 1
6   b    2      6 0
7   b    3      7 0
8   b    4      8 0
9   c    1      9 1
10  c    2     10 0
11  c    3     11 0
12  c    4     12 0

I want a new column y2 defined recursively as y2 = lag(y2) * 0.8 + invest, where, within each id group, y2 in the first year equals y. Like this:

id  year invest y   y2
a   1    1      1   1
a   2    2      0   2.8
a   3    3      0   5.24
a   4    4      0   8.192
b   1    5      1   1
b   2    6      0   6.8
b   3    7      0   12.44
b   4    8      0   17.952
c   1    9      1   1
c   2    10     0   10.8
c   3    11     0   19.64
c   4    12     0   27.712
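Written out, the rule is the following recurrence, where i indexes the id group and t the year. Unrolling it for group a reproduces the expected y2 values above:

```latex
y2_{i,1} = y_{i,1}, \qquad y2_{i,t} = 0.8\, y2_{i,t-1} + \mathit{invest}_{i,t} \quad (t \ge 2)

\text{group } a:\quad 1,\;\; 0.8 \cdot 1 + 2 = 2.8,\;\; 0.8 \cdot 2.8 + 3 = 5.24,\;\; 0.8 \cdot 5.24 + 4 = 8.192
```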

Thanks!

2 answers:

Answer 0 (score: 5)

You can use Reduce to compute the recursion, as follows:

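library(data.table)
# Within each id, start from y[1] and fold over the remaining invest values;
# each step computes invest + 0.8 * (previous y2).
setDT(dt)[, y2 := Reduce(function(prev, inv) inv + prev * 0.8,
                         invest[-1L], init = y[1L], accumulate = TRUE),
          by = .(id)]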
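To see what `Reduce(..., accumulate = TRUE)` returns, here is a minimal sketch that runs the same fold on group a alone. The vectors are copied from the question's data; `invest_a`, `y1_a` and `y2_a` are illustrative names only. With `init` supplied, the result starts with `init` and then holds every intermediate value, so it has the same length as the group.

```r
# Group "a" on its own: y[1] = 1 and invest = 1:4
invest_a <- c(1, 2, 3, 4)
y1_a <- 1

# init is the first element of the output; accumulate = TRUE keeps every
# intermediate result, giving length(invest_a[-1L]) + 1 = 4 values.
y2_a <- Reduce(function(prev, inv) inv + prev * 0.8,
               invest_a[-1L], init = y1_a, accumulate = TRUE)
print(y2_a)
# [1] 1.000 2.800 5.240 8.192
```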

Or, for better performance, use Rcpp:

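library(Rcpp)
# res[0] = y0; res[i] = 0.8 * res[i-1] + invest[i] for i >= 1
cppFunction(
"NumericVector func(double y0, NumericVector invest) {
    int n = invest.size();
    NumericVector res(n);
    if (n == 0) return res;  // guard against an empty group
    res[0] = y0;
    for (int i = 1; i < n; i++) {
        res[i] = 0.8 * res[i-1] + invest[i];
    }
    return res;
}")
# by() returns one result per id in sorted order, so unlist() lines up
# with the rows only because dt is already sorted by id.
dt$y2 <- unlist(by(dt, dt$id, function(x) func(x$y[1L], x$invest)))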

Output:

    id year invest y     y2
 1:  a    1      1 1  1.000
 2:  a    2      2 0  2.800
 3:  a    3      3 0  5.240
 4:  a    4      4 0  8.192
 5:  b    1      5 1  1.000
 6:  b    2      6 0  6.800
 7:  b    3      7 0 12.440
 8:  b    4      8 0 17.952
 9:  c    1      9 1  1.000
10:  c    2     10 0 10.800
11:  c    3     11 0 19.640
12:  c    4     12 0 27.712
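A further base-R option, not taken from either answer, is `stats::filter()` with `method = "recursive"`, which computes `out[t] = x[t] + 0.8 * out[t - 1]` with `out[0] = 0`. This sketch assumes that replacing the first `invest` of each group with `y[1]` is an acceptable way to impose the starting value; `lag_sum` and `dat` are hypothetical names.

```r
# Same data as in the question, under a separate name
dat <- data.frame(id = rep(letters[1:3], each = 4),
                  year = rep(1:4, 3),
                  invest = 1:12,
                  y = rep(c(1, 0, 0, 0), 3))

# Hypothetical helper for one group: recursive filter with out[0] = 0.
# Feeding y[1] as the first input makes out[1] = y[1]; later values follow
# out[t] = invest[t] + rate * out[t - 1].
lag_sum <- function(y, invest, rate = 0.8) {
  x <- c(y[1L], invest[-1L])
  as.numeric(stats::filter(x, rate, method = "recursive"))
}

# split() by id, apply per group, then unsplit() back into row order
dat$y2 <- unsplit(lapply(split(dat, dat$id),
                         function(g) lag_sum(g$y, g$invest)),
                  dat$id)
print(dat$y2)
```

`stats::` is written out because `dplyr::filter()` masks `stats::filter()` once dplyr or the tidyverse is attached.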

Answer 1 (score: 3)

A tidyverse option would be:

library(tidyverse)
out <- dt %>%
         group_by(id) %>% 
         mutate(y2 = accumulate(invest[-1], ~ .x * 0.8 + .y, .init = y[1]))
out
# A tibble: 12 x 5
# Groups:   id [3]
#   id     year invest     y    y2
#   <fct> <int>  <int> <dbl> <dbl>
# 1 a         1      1     1  1   
# 2 a         2      2     0  2.8 
# 3 a         3      3     0  5.24
# 4 a         4      4     0  8.19
# 5 b         1      5     1  1   
# 6 b         2      6     0  6.8 
# 7 b         3      7     0 12.4 
# 8 b         4      8     0 18.0 
# 9 c         1      9     1  1   
#10 c         2     10     0 10.8 
#11 c         3     11     0 19.6 
#12 c         4     12     0 27.7 

Note: the tibble (tbl_df) print method rounds some values for display. Extracting the column shows the exact values:

out$y2
#[1]  1.000  2.800  5.240  8.192  1.000  6.800 12.440 17.952  1.000 10.800
#[11] 19.640 27.712