性能改善

时间:2013-03-28 20:08:33

标签: r data.table

我正在尝试对滞后(lag)/前向(forward)值应用函数。我大量使用 data.table,并且已经有了能运行的代码,但考虑到 data.table 的强大功能,我认为一定有更简单的方法来实现同样的功能,并且可能提高性能(我在函数内部创建了很多变量)。以下是函数代码(也可在 https://gist.github.com/tomaskrehlik/5262087#file-gistfile1-r 中找到):

#' Add lagged or forwarded copies of a column to a data.table
#'
#' Orders `data` by `date_variable` (by setting it as the key, in place) and
#' then adds, by reference, one new column per shift step `j` in
#' `1..abs(lags)`:
#'   * `<variable>_lag<j>` holds the value `j` rows earlier when `lags > 0`;
#'   * `<variable>_forward<j>` holds the value `j` rows later when `lags < 0`.
#' Rows whose source row falls outside the table get `NA`. No temporary
#' helper column is created, so an existing `index` column is left alone.
#'
#' @param data A data.table; modified by reference.
#' @param variable Name of the column to shift.
#' @param lags Number of steps: positive for lags, negative for forwards;
#'   `0` returns `data` untouched.
#' @param date_variable Column name(s) used to order the rows.
#' @return `data`, invisibly (it is also modified in place).
lag_variable <- function(data, variable, lags, date_variable = c("Date")) {
  if (lags == 0) {
    return(data)
  }
  if (!variable %in% names(data)) {
    stop("Column '", variable, "' not found in data.", call. = FALSE)
  }
  name <- if (lags > 0) "lag" else "forward"
  # Lags read from earlier rows (row - j); forwards from later rows (row + j).
  step_sign <- if (lags > 0) -1L else 1L

  data.table::setkeyv(data, date_variable)
  values <- data[[variable]]
  n_rows <- length(values)
  rows <- seq_len(n_rows)

  for (j in seq_len(abs(lags))) {
    source_row <- rows + step_sign * j
    source_row[source_row < 1L | source_row > n_rows] <- NA_integer_
    # Indexing with NA positions yields NA while keeping the column's class
    # (Date, factor, ...), which c(NA, x) would not.
    data.table::set(
      data,
      j = paste0(variable, "_", name, j),
      value = values[source_row]
    )
  }
  invisible(data)
}

#' Apply a function row-wise over a backward / forward / two-sided window
#'
#' Uses `lag_variable()` to add `steps` shifted copies of `variable` to
#' `data` (by reference), applies `func.name` to each row of those copies,
#' and stores the result in a new column `transf`. The current row's own
#' value is not part of the window.
#'
#' @param data A data.table; modified by reference.
#' @param func.name A function, or the name of one, passed to `match.fun()`;
#'   it receives one row of window values.
#' @param variable Name of the column to window over.
#' @param direction `"forward"` (next `steps` rows), `"backward"` (previous
#'   `steps` rows) or `"window"` (both sides).
#' @param steps Number of rows (>= 1) taken on each included side.
#' @param date_variable Column name(s) used to order the rows.
#' @param clean If `TRUE`, drop the helper lag/forward columns afterwards.
#' @return `data` with the added `transf` column.
window_func <- function(data, func.name, variable, direction = "window", steps,
                        date_variable = c("Date"), clean = TRUE) {
  fun <- match.fun(func.name)
  if (!direction %in% c("forward", "backward", "window")) {
    stop("The direction must be either backward, forward or window.")
  }
  if (length(steps) != 1L || is.na(steps) || steps < 1) {
    stop("`steps` must be a single number >= 1.", call. = FALSE)
  }

  # Build the helper column names exactly (instead of regex-matching against
  # a global table), so stale or similarly-prefixed columns are never picked.
  offsets <- seq_len(steps)
  lag_cols <- paste0(variable, "_lag", offsets)
  forward_cols <- paste0(variable, "_forward", offsets)

  if (direction %in% c("forward", "window")) {
    lag_variable(data, variable, -steps, date_variable)
  }
  if (direction %in% c("backward", "window")) {
    lag_variable(data, variable, steps, date_variable)
  }
  # Column order of the window matrix: lags first, then forwards.
  cols <- switch(direction,
    forward = forward_cols,
    backward = lag_cols,
    window = c(lag_cols, forward_cols)
  )

  window_values <- as.matrix(data[, cols, with = FALSE])
  data.table::set(data, j = "transf", value = apply(window_values, 1L, fun))
  if (clean) {
    # value = NULL removes the columns by reference.
    data.table::set(data, j = cols, value = NULL)
  }
  data
}

# Typical use:
# I have a data.table DT with variables Date (class IDate), value1, value2
# I want, for each date, the sum of value1 over the next five days (a forward rolling sum)
# window_func(DT, "sum", "value1", direction = "forward", steps = 5)

编辑:可以通过以下方式创建示例数据:

a <- data.table(Date = seq_len(1000L), value = rnorm(n = 1000L))

对于每个日期(这里只是用整数作为示例,具体类型并不重要),我想计算接下来十个观测值的总和。要运行代码并获得输出,请执行以下操作:

window_func(
  data = a,
  func.name = "sum",
  variable = "value",
  direction = "forward",
  steps = 10,
  date_variable = "Date",
  clean = TRUE
)

该函数首先取出指定变量,(使用 lag_variable 函数)创建十个滞后或前向变量,然后对这些列逐行应用函数,最后清理这些临时列。代码之所以冗长,是因为我有时只需要对滞后观测值应用函数,有时只需要对前向观测值应用函数,有时则需要对两者同时应用,这称为窗口(window)。

有什么更好的实现方式吗?我的代码似乎太冗长了。

1 个答案:

答案 0 :(得分:5)

我不确定你的其他函数,但你可以按以下方式高效地计算窗口求和(注意:这里每行的结果包含当前行及其后 9 行,共 10 个值;而问题中的 forward 窗口不包含当前行):

# Window length: each row gets the sum of its own value and the next k - 1.
# (Named k rather than `lag` so it is defined and does not mask stats::lag.)
k <- 10L
# Rows are split into k residue classes r = 1..k of Date. For class r the
# whole table is cut into consecutive blocks of k dates starting at Date == r,
# so block g >= 1 is exactly the window starting at the g-th row of the class;
# block 0 is the leading partial block (Date < r) and is dropped. Using
# (Date - 1) %% k + 1 (not Date %% k) keeps every class in 1..k, so multiples
# of k are handled like any other residue.
# Assumes Date holds the consecutive integers 1..nrow(a) in sorted order.
a[, lagSum := {
    blocks <- a[, list(sum = sum(value)), by = list(grp = (Date + k - r) %/% k)]
    blocks[grp != 0, sum]
  },
  by = list(r = (Date - 1L) %% k + 1L)]

例如:

# Recreate the sample data after seeding so the printed output is reproducible.
set.seed(1)
a <- data.table(Date = 1:1000, value = rnorm(1000))
# Window length: sum of the current value and the next k - 1 values.
k <- 10L
# Group rows by Date residue class r in 1..k; within each class, blocks of k
# dates starting at Date == r line up one-to-one with the class's rows, and
# block 0 (Date < r) is a leading partial block that is dropped.
a[, lagSum := {
    blocks <- a[, list(sum = sum(value)), by = list(grp = (Date + k - r) %/% k)]
    blocks[grp != 0, sum]
  },
  by = list(r = (Date - 1L) %% k + 1L)]

> a
      Date      value      lagSum
   1:    1 -0.6264538  1.32202781
   2:    2  0.1836433  3.46026279
   3:    3 -0.8356286  3.66646270
   4:    4  1.5952808  3.88085074
   5:    5  0.3295078  0.07087005
  ---                            
 996:  996 -0.3132929 -3.79332038
 997:  997 -0.8806707 -3.48002750
 998:  998 -0.4192869 -2.59935677
 999:  999 -1.4827517 -2.18006988
1000: 1000 -0.6973182 -0.69731820

验证前几个值是否正确:

# Recompute the first n window sums directly and print them for comparison.
n <- 5
for (i in seq_len(n)) {
  a[seq(i, length.out = 10), print(sum(value))]
}

#  [1] 1.322028
#  [1] 3.460263
#  [1] 3.666463
#  [1] 3.880851
#  [1] 0.07087005

基准测试(与 for 循环对比,所以不太公平):

set.seed(1)
a <- data.table(Date = 1:1000, value = rnorm(1000))
# Window length (defined explicitly; a bare `lag` would be stats::lag).
k <- 10L

# Grouped approach: residue classes r = 1..k of Date, blocks of k dates per
# class, leading partial block (grp 0) dropped.
system.time({
  a[, lagSum := {
      blocks <- a[, list(sum = sum(value)), by = list(grp = (Date + k - r) %/% k)]
      blocks[grp != 0, sum]
    },
    by = list(r = (Date - 1L) %% k + 1L)]
})

#  user  system elapsed 
# 0.049   0.001   0.056 



set.seed(1)
a <- data.table(Date = 1:1000, value = rnorm(1000))
# Window length (defined explicitly; a bare `lag` would be stats::lag).
k <- 10L

# Naive loop: for every row that has a full window ahead of it, store the sum
# of that row and the next k - 1 rows in that row only (assigning to all k
# rows would overwrite earlier results and leave the tail rows wrong).
system.time({
  for (pos in seq_len(nrow(a) - k + 1L)) {
    a[pos, lagSum := a[seq(pos, length.out = k), sum(value)]]
  }
})

#  user  system elapsed 
# 1.526   0.019   2.220