Calculating a statistic for every possible split of a vector into n groups

Time: 2015-01-23 21:26:08

Tags: r performance apply

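I want to use brute force to calculate a custom statistic for every possible split of a vector. I figured out how to create all the splits (using the partitions package), and I also managed to calculate what I need. However, my code is slow, and I don't see any obvious opportunity to speed it up. How can I make it faster? Execution time grows exponentially as obs (the number of rows in the data frame) increases: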

#install.packages('partitions', dependencies = TRUE)
library(partitions)

fa <- function(obs) {
# Sample Data
tmpDf <- data.frame(x = seq(obs), 
                    y1 = trunc(runif(obs) * 1000),
                    y2 = trunc(runif(obs) * 1000)
                    )
# Partitions for given data (from 1 up to 9 splits)
partitions <- restrictedparts(obs, 10, include.zero = TRUE)
# Stat for every split
splitsStat <- apply(partitions, 2, function(part) {
  # Calculate indexes of splits based on partitions
  tmp <- cumsum(part[part != 0])
  # Last element is always equal to obs, has to be removed
  tmp <- tmp[-length(tmp)]
  # Add ids of splits to tmpDf data frame
  if(length(tmp) == 0) {
    tmpDf$ints <- 1
  } else if(length(tmp) == 1 ) {
    tmpDf$ints <- ifelse(tmpDf$x > tmp, 1, 0)      
  } else {
    tmpDf$ints <- cut(tmpDf$x, breaks = tmp, labels = FALSE)  
  }
  # I need to aggregate by splits to calculate my statistic
  out  <- aggregate(cbind(y1, y2) ~ ints, data = tmpDf, sum)
  # Calculate statistic
  sum(log(out$y1 / out$y2) * ((out$y1 / sum(out$y1)) - (out$y2 / sum(out$y2))))
  }
)
}

# This takes around a minute to calculate on modern laptop
library(microbenchmark)
microbenchmark(
  fa(5),  fa(10), fa(15), fa(20), fa(30),  fa(40),  times = 1)

Results:

Unit: milliseconds
   expr         min          lq        mean      median          uq         max neval
  fa(5)    11.64077    11.64077    11.64077    11.64077    11.64077    11.64077     1
 fa(10)    70.50710    70.50710    70.50710    70.50710    70.50710    70.50710     1
 fa(15)   318.19676   318.19676   318.19676   318.19676   318.19676   318.19676     1
 fa(20)   890.54962   890.54962   890.54962   890.54962   890.54962   890.54962     1
 fa(30)  6382.75802  6382.75802  6382.75802  6382.75802  6382.75802  6382.75802     1
 fa(40) 29703.39809 29703.39809 29703.39809 29703.39809 29703.39809 29703.39809     1

2 Answers:

Answer 0 (score: 3)

This line

# Partitions for given data (from 1 up to 9 splits)
partitions <- restrictedparts(obs, 10, include.zero = TRUE)

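returns a matrix with a very large number of columns, and that number grows exponentially with obs. For small values of obs this is not a problem, but once obs reaches about 100: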

ncol(restrictedparts(100, 10))
# 6292069

aa <- restrictedparts(60, 10, include.zero = TRUE)
microbenchmark(apply(aa, 2, sum), times = 1)
# Unit: milliseconds
#               expr      min       lq     mean   median       uq      max neval
#  apply(aa, 2, sum) 333.3406 333.3406 333.3406 333.3406 333.3406 333.3406     1

aa <- restrictedparts(100, 10, include.zero = TRUE)
microbenchmark(apply(aa, 2, sum), times = 1)
# Unit: seconds
#               expr      min       lq     mean   median       uq      max neval
#  apply(aa, 2, sum) 27.60511 27.60511 27.60511 27.60511 27.60511 27.60511     1

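If you use this approach (enumerating the additive integer partitions), your execution time will grow exponentially in obs. Of course, asking apply to run a non-trivial function over that many columns also contributes to the problem (the comments already mention cut and aggregate; I am pointing out why the growth in "obs" is exponential).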
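To see the growth without building the matrix at all, the column count can be computed directly. The following is a minimal sketch in base R; `count_parts` is a hypothetical helper, not part of the partitions package. It relies on the standard fact that partitions of n into at most k parts are equinumerous with partitions of n whose parts are all at most k. Strictly speaking, with the number of parts capped at 10 the count grows like a polynomial of degree 9 in obs rather than a true exponential, but the practical effect on run time is the same.

```r
# Count partitions of n into at most k parts without enumerating them.
# Equivalent count: partitions of n using only part sizes 1..k.
count_parts <- function(n, k) {
  ways <- c(1, rep(0, n))            # ways[m + 1]: partitions of m found so far
  for (size in seq_len(min(k, n))) { # allow one more part size per pass
    for (m in size:n) {
      ways[m + 1] <- ways[m + 1] + ways[m - size + 1]
    }
  }
  ways[n + 1]
}

# Number of columns apply() would have to visit for each obs
obs <- c(5, 10, 20, 40, 60, 100)
data.frame(obs = obs, columns = sapply(obs, count_parts, k = 10))
```

The value for obs = 100 can be compared with the `ncol(restrictedparts(100, 10))` figure quoted above.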

Answer 1 (score: 0)

It turned out that the biggest problem was my approach:

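  • I need every possible split of the vector where order matters, so I have to use the compositions function. The number of splits then grows even faster: for 100 elements and 6 splits there are 71,523,144 possible splits.
  • I precomputed the value of my function for every possible single segment (for 100 elements this fits easily in a 100x100 matrix, which I call statMat). From this matrix I can get the value of the statistic for a given from/to index pair by simple subsetting, which I guess saves a lot of time.
  • Now, using the generated compositions (for 6 splits, a matrix with 6 rows and 71,523,144 columns), I can take cumulative column sums to get the index values of the splits (this is another Rcpp function).
  • The final Rcpp function is a for loop in which I simply sum the results, using a slightly obscure index definition, to pick the partial function values from the precomputed matrix.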
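The first three steps can be sketched in base R on a tiny example. This is an illustration under assumptions, not the answer's actual code: the layout `statMat[from, to]`, the sample data, and the use of `combn` in place of `compositions` are all choices made here, so the column order may differ from what the partitions package returns. The count of ordered splits into k non-empty segments is `choose(n - 1, k - 1)`, and `choose(99, 5)` is exactly the 71,523,144 mentioned above.

```r
set.seed(1)
n  <- 8                                 # tiny example; the answer used 100
k  <- 3                                 # number of segments per split
y1 <- trunc(runif(n) * 1000) + 1        # +1 keeps every sum positive (assumption)
y2 <- trunc(runif(n) * 1000) + 1

# Prefix sums turn every segment sum into a difference of two entries.
c1 <- c(0, cumsum(y1))
c2 <- c(0, cumsum(y2))

# statMat[from, to]: contribution of segment from..to to the statistic
# from the question (only the upper triangle is filled).
statMat <- matrix(NA_real_, n, n)
for (from in seq_len(n)) {
  to <- from:n
  s1 <- c1[to + 1] - c1[from]
  s2 <- c2[to + 1] - c2[from]
  statMat[from, to] <- log(s1 / s2) * (s1 / sum(y1) - s2 / sum(y2))
}

# Ordered splits: choose k - 1 cut points among the n - 1 gaps.
cuts  <- combn(n - 1, k - 1)
sizes <- apply(rbind(0, cuts, n), 2, diff)  # segment sizes, one composition per column
ends  <- apply(sizes, 2, cumsum)            # cumulative sums recover segment end indices

ncol(sizes) == choose(n - 1, k - 1)
```

Because the statistic is a sum of per-segment terms whose denominators `sum(y1)` and `sum(y2)` do not depend on the split, each split only needs k lookups in `statMat`.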

It runs in 6.3 seconds on my computer.

Rcpp function:

#include <Rcpp.h>
using namespace Rcpp;

// x: one column per split; rows hold the cumulative split positions (1-based).
// statMat: precomputed statistic for each (from, to) segment.
// [[Rcpp::export]]
NumericVector part(IntegerMatrix x, NumericMatrix statMat) {
  int nrow = x.nrow();
  int ncol = x.ncol();
  NumericVector out(ncol);
  for (int j = 0; j < ncol; j++) {
    double total = 0;
    /* Loop through the split vector */
    for (int i = 0; i < nrow; i++) {
      if (i == 0) {
        /* First split: range from 0 up to split value - 1
           (because the index starts at 0) */
        total += statMat(0, x(i, j) - 1);
      } else {
        /* Later splits: the range starts at the previous split value + 1,
           which is just the previous split value once the index is shifted
           by one. The upper bound is the current split value - 1
           (to adjust the index). */
        total += statMat(x(i - 1, j), x(i, j) - 1);
      }
    }
    out[j] = total;
  }
  return out;
}
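For checking the index arithmetic on small inputs, a plain-R equivalent of the loop can be useful. The sketch below is an assumption about the input layout (columns of `x` are splits, rows are 1-based cumulative segment ends, and `statMat[from, to]` holds one segment's contribution); `part_r` is a hypothetical name. It is a reference for testing, not a replacement for the Rcpp version at full scale.

```r
# Plain-R reference for the lookup-and-sum step.
part_r <- function(x, statMat) {
  k <- nrow(x)
  # Each segment starts one past the previous segment's end; the first starts at 1.
  starts <- rbind(1L, x[-k, , drop = FALSE] + 1L)
  # Matrix indexing with a two-column (from, to) index picks one value per segment.
  vals <- statMat[cbind(c(starts), c(x))]
  colSums(matrix(vals, nrow = k))
}

# Toy check: let each segment contribute its own length, so every split
# of 8 elements must add up to 8.
n <- 8
statMat <- outer(seq_len(n), seq_len(n), function(from, to) to - from + 1)
comp <- cbind(c(2, 3, 3), c(1, 1, 6))   # two compositions of 8 into 3 parts
x <- apply(comp, 2, cumsum)             # cumulative sums give the segment ends
part_r(x, statMat)                      # both entries equal 8
```

In 0-based C++ terms, `statMat[x[i - 1] + 1, x[i]]` here corresponds to `statMat(x(i - 1, j), x(i, j) - 1)` in the function above.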