在第一个非NA值之后递归填充元素

时间:2015-06-11 15:10:48

标签: r performance for-loop recursion

数据示例:

   >w
     date         V1         V2         V3
    1     1         NA         NA         NA
    2     2         NA         NA         NA
    3     3 -0.2357066         NA -0.5428883
    4     4         NA         NA         NA
    5     5         NA -0.4333103         NA
    6     6         NA         NA         NA
    7     7 -0.6494716  0.7267507  1.1519118
    8     8         NA         NA         NA
    9     9         NA         NA         NA
    10   10         NA         NA         NA

> r
   date           V1         V2          V3
1     1  1.262954285  0.7635935 -0.22426789
2     2 -0.326233361 -0.7990092  0.37739565
3     3  1.329799263 -1.1476570  0.13333636
4     4  1.272429321 -0.2894616  0.80418951
5     5  0.414641434 -0.2992151 -0.05710677
6     6 -1.539950042 -0.4115108  0.50360797
7     7 -0.928567035  0.2522234  1.08576936
8     8 -0.294720447 -0.8919211 -0.69095384
9     9 -0.005767173  0.4356833 -1.28459935
10   10  2.404653389 -1.2375384  0.04672617

我正在尝试使用以下规则填充ww(t+1) <- w(t)*r(t),但仅限于第一个非NA元素之后的值。 for循环等价物是:

for (i in 1:(nrow(w)-1)) {
  for (j in 2:ncol(w)){
    if (is.na(w[i+1,j])) {
      w[i+1,j] <- w[i,j]*r[i,j]
    }
  }
}

并给出:

  > w
   date           V1          V2           V3
1     1           NA          NA           NA
2     2           NA          NA           NA
3     3 -0.235706556          NA -0.542888255
4     4 -0.313442405          NA -0.072386744
5     5 -0.398833307 -0.43331032 -0.058212660
6     6 -0.165372814  0.12965300  0.003324337
7     7 -0.649471647  0.72675075  1.151911754
8     8  0.603077961  0.18330358  1.250710490
9     9 -0.177739406 -0.16349234 -0.864183216
10   10  0.001025054 -0.07123088  1.110129201

这有点类似于cumprod,但是我被卡住了。是否可以避免for循环(或至少其中一个),以便加快速度?

数据可以通过以下方式复制:

set.seed(0)
r <- as.data.frame(matrix(data = rnorm(30), nrow = 10, ncol = 3))
w <- as.data.frame(matrix(data = NA, nrow =10, ncol = 3))

w[3, c(1,3)] <- rnorm(2)
w[5, 2] <- rnorm(1)
w[7,] <- rnorm(ncol(w))
date <- 1:nrow(w)
w <- cbind(date, w)
r <- cbind(date, r)

2 个答案:

答案 0 :(得分:3)

如果您有几列,可以按照data.table操作替换内部循环。

library(data.table) # v1.9.5
fdt <- function(w, r){  
  for (j in 2:ncol(w)){
    w[,j] <- data.table(x=r[, j], z=w[, j])[,ifelse(is.na(z), z[1L]*shift(cumprod(x)), z), cumsum(!is.na(z))][,V1]
  }  
  w
}

对于包含100000行的数据框,我的计算机上大约需要3秒。

w <- do.call('rbind', lapply(1:10000, function(i)w))
r <- do.call('rbind', lapply(1:10000, function(i)r))
system.time(fdt(w, r))
#user  system elapsed 
#2.923   0.004   2.928 

然而,嵌套循环需要200秒

system.time(f(w, r))
#   user  system elapsed 
#206.406   0.043 206.559 
f <- function(w, r){
  for (i in 1:(nrow(w)-1)) {
    for (j in 2:ncol(w)){
      if (is.na(w[i+1,j])) {
        w[i+1,j] <- w[i,j]*r[i,j]
      }
    }
  }
  w
}

<强> [编辑]

dplyr版本的运行速度略快于fd

library(dplyr)
fdp <- function(w, r){    
  for (j in 2:ncol(w)){
    d <- data_frame(x=r[, j], z=w[, j]) %>% 
      group_by(cumsum(!is.na(z))) %>% 
      mutate(v=ifelse(is.na(z), z[1L]*lag(cumprod(x)), z))
    w[, j] <- d$v
  }    
  w
}
system.time(fdp(w, r))
#   user  system elapsed 
#  2.458   0.008   2.467 

<强> [EDIT2]

对于几百万行,data.table解决方案仍然很慢。你可以使用Rcpp很好地加快速度。

Rcpp::cppFunction('NumericMatrix fill(NumericMatrix w, NumericMatrix r) {
  for (int i = 0; i < w.nrow()-1; i++) {
    for (int j = 0; j < w.ncol(); j++) {
      if (NumericVector::is_na(w(i+1,j))) {
        w(i+1,j) = w(i,j)*r(i,j);
      }
    }    
  }
  return w;
}') 

对于1M行,只需不到一秒钟。

system.time(fill(as.matrix(w[,-1]), as.matrix(r[,-1])))
 #  user  system elapsed 
 # 0.913   0.004   0.917 

答案 1 :(得分:1)

这是另一种方法:

    library(zoo)

cumprodsplit <- function(col, r, w){

  # fill the NAs
  fill_w <- na.locf(w)[[col]]

  # groups
  f <- cumsum(!is.na(w[[col]]))

  # split w
  splits <- split(fill_w, f)

  # generate the cumprods
  cumprods <- lapply(split(r[[col]], f), 
                           function(x) c(1, cumprod(x)[-length(x)]))
                     # multiply
  vec <- mapply('*', splits, cumprods, SIMPLIFY = FALSE)

                     #unlist
  setNames(data.frame(unlist(vec, use.names = FALSE)), col)
}


do.call("cbind", lapply(names(w)[-1], cumprodsplit, r, w))

             V1          V2           V3
1            NA          NA           NA
2            NA          NA           NA
3  -0.235706556          NA -0.542888255
4  -0.313442405          NA -0.072386744
5  -0.398833307 -0.43331032 -0.058212660
6  -0.165372814  0.12965300  0.003324337
7  -0.649471647  0.72675075  1.151911754
8   0.603077961  0.18330358  1.250710490
9  -0.177739406 -0.16349234 -0.864183216
10  0.001025054 -0.07123088  1.110129201