提高R中嵌套For循环的性能

时间:2015-03-27 07:35:50

标签: r performance for-loop

以下是具有158K观测值的大数据帧的子集,名为“sh_data”。

Patient_ID Age_in_days DEMAdmNo 
396076 28542 0 
396076 28570 0 
396076 28598 0 
396076 28626 0 
396076 28654 0 
396076 28682 0 
396076 28710 0 
396076 28738 0 
396076 28766 0 
396076 28794 0 
396076 28822 0 
396076 28850 0 
396076 28878 0 
396076 28906 0 
396076 28934 0 
396076 28962 0 
396076 28990 0 
396076 29018 0 
396076 29046 0 
396076 29074 0 
396076 29102 1 
396076 29165 0 
396076 29200 0 
396076 29228 0 
396076 29263 0 
396076 29200 0 
396076 29228 0 
396076 29263 0 

我正在尝试计算过去六个月中第三列为1的记录的实例数(表示为LACE_E)。所以对于第一个记录,年龄最小的时候它将为零。并且对于第二记录,如果天的年龄差异<= 183天并且第一记录的列3为零则那么它将是一个,依此类推。

我在R中发出以下查询:

LACE_E <- numeric(0)

for(i in 1:length(sh_data[,1]))
{
  LACE_E[i] = 0
  for(j in 1:length(sh_data[,1]))
  {
    if(sh_data$Patient_ID[i] == sh_data$Patient_ID[j] & sh_data$Age_in_days[i] > sh_data$Age_in_days[j] & (sh_data$Age_in_days[i]- sh_data$Age_in_days[j])<= 183 & sh_data$DEMAdmNo[j] == 1)
    {LACE_E[i] = LACE_E[i] + 1}
  }
}

此查询需要很长时间才能处理。在我的系统中处理100行1小时。请帮助!!

2 个答案:

答案 0 :(得分:5)

您无需转到Rcppdata.table即可获得非常显着的改善。

获取原始数据并进行复制以获得更多可用时间:

d <- read.table(head = TRUE, text = 
"Patient_ID Age_in_days DEMAdmNo 
396076 28542 0 
396076 28570 0 
396076 28598 0 
396076 28626 0 
396076 28654 0 
396076 28682 0 
396076 28710 0 
396076 28738 0 
396076 28766 0 
396076 28794 0 
396076 28822 0 
396076 28850 0 
396076 28878 0 
396076 28906 0 
396076 28934 0 
396076 28962 0 
396076 28990 0 
396076 29018 0 
396076 29046 0 
396076 29074 0 
396076 29102 1 
396076 29165 0 
396076 29200 0 
396076 29228 0 
396076 29263 0 
396076 29200 0 
396076 29228 0 
396076 29263 0 ")

d <- rbind(d, d, d, d, d, d, d, d, d, d)

您的原始代码作为函数和计时运行:

f0 <- function(sh_data) {
    LACE_E <- numeric(0)

    for(i in 1:length(sh_data[,1])) {
        LACE_E[i] = 0
        for(j in 1:length(sh_data[,1])) {
            if(sh_data$Patient_ID[i] == sh_data$Patient_ID[j] &
               sh_data$Age_in_days[i] > sh_data$Age_in_days[j] &
               (sh_data$Age_in_days[i]- sh_data$Age_in_days[j])<= 183 &
               sh_data$DEMAdmNo[j] == 1) {
                LACE_E[i] = LACE_E[i] + 1
            }
        }
    }
}

system.time(v0 <- f0(d))
##   user  system elapsed 
##  4.803   0.007   4.812 

分析显示在内循环中使用$提取列的时间大约为90%:

Rprof()
v0 <- f0(d)
Rprof(NULL)
head(summaryRprof()$by.total)
## "f0"                  4.94    100.00      0.60    12.15
## "$"                   4.24     85.83      0.72    14.57
## "$.data.frame"        3.52     71.26      0.36     7.29
## "[["                  3.16     63.97      0.46     9.31
## "[[.data.frame"       2.70     54.66      0.96    19.43
## "%in%"                0.92     18.62      0.22     4.45

将色谱柱提取移出循环会大大提高性能:

f1 <- function(sh_data) {
    LACE_E <- numeric(0)

    Patient_ID <- sh_data$Patient_ID
    Age_in_days <- sh_data$Age_in_days
    DEMAdmNo <- sh_data$DEMAdmNo
    for(i in 1:length(sh_data[,1])) {
        LACE_E[i] = 0
        for(j in 1:length(sh_data[,1])) {
            if(Patient_ID[i] == Patient_ID[j] &
               Age_in_days[i] > Age_in_days[j] &
               (Age_in_days[i]- Age_in_days[j])<= 183 &
               DEMAdmNo[j] == 1) {
                LACE_E[i] = LACE_E[i] + 1
            }
        }
    }
}

system.time(v1 <- f1(d))
##   user  system elapsed 
##  0.163   0.000   0.164 

从一个空洞的结果开始并发展它几乎总是一个坏主意;预先分配结果是更好的做法。在这种情况下,算法已经O(n^2),所以你没有注意到,但它确实有所作为,特别是在添加其他改进之后。 f2预先分配结果:

f2 <- function(sh_data) {
    n <- nrow(sh_data)
    LACE_E <- numeric(n)

    Patient_ID <- sh_data$Patient_ID
    Age_in_days <- sh_data$Age_in_days
    DEMAdmNo <- sh_data$DEMAdmNo
    for(i in 1:n) {
        LACE_E[i] = 0
        for(j in 1:n) {
            if(Patient_ID[i] == Patient_ID[j] &
               Age_in_days[i] > Age_in_days[j] &
               (Age_in_days[i]- Age_in_days[j])<= 183 &
               DEMAdmNo[j] == 1) {
                LACE_E[i] = LACE_E[i] + 1
            }
        }
    }
}

system.time(v2 <- f2(d))
##   user  system elapsed 
##  0.147   0.000   0.148 

使用正确的逻辑运算符&&代替&可以进一步改进:

f3 <- function(sh_data) {
    n <- nrow(sh_data)
    LACE_E <- numeric(n)

    Patient_ID <- sh_data$Patient_ID
    Age_in_days <- sh_data$Age_in_days
    DEMAdmNo <- sh_data$DEMAdmNo
    for(i in 1:n) {
        LACE_E[i] = 0
        for(j in 1:n) {
            if(Patient_ID[i] == Patient_ID[j] &&
               Age_in_days[i] > Age_in_days[j] &&
               (Age_in_days[i] - Age_in_days[j]) <= 183 &&
               DEMAdmNo[j] == 1) {
                LACE_E[i] = LACE_E[i] + 1
            }
        }
    }
}

system.time(v3 <- f3(d))
##   user  system elapsed 
##  0.108   0.002   0.111 

这些是您前往Rcpp所需采取的所有步骤,但您无需前往Rcpp即可。

要获得更高的速度,您可以进行字节编译:

f3c <- compiler::cmpfun(f3)
system.time(v3 <- f3c(d))
##   user  system elapsed 
## 0.036   0.000   0.036 

这些计算在R 3.1.3中完成。 microbenchmark摘要:

microbenchmark(f0(d), f1(d), f2(d), f3(d), f3c(d), times = 10)
## Unit: milliseconds
##   expr        min        lq       mean     median         uq        max  neval  cld
##   f0(d) 5909.39756 5924.8493 5963.63608 5947.23469 6011.94567 6048.03571    10    d
##   f1(d)  196.16466  197.3252  200.22471  197.93345  202.49236  210.22011    10   c 
##   f2(d)  187.68169  190.5644  194.02454  192.47596  195.63821  204.27415    10   c 
##   f3(d)  109.17816  110.6695  112.55218  111.93915  114.43341  116.92342    10  b  
##  f3c(d)   37.37348   38.8757   39.34564   39.58563   40.50597   40.58568    10 a

R.version$version.string
## [1] "R version 3.1.3 Patched (2015-03-16 r68072)"

将于4月发布的R 3.2.0对解释器和字节码引擎进行了一些改进,进一步提高了性能:

## Unit: milliseconds
##    expr        min         lq       mean     median         uq        max neval  cld
##   f0(d) 4351.33908 4429.71559 4471.32960 4479.13901 4499.39769 4601.05390    10    d
##   f1(d)  183.57765  184.68961  190.10391  187.30951  199.56235  200.57238    10   c 
##   f2(d)  177.47063  181.09790  189.78291  185.58951  190.34782  233.90264    10   c 
##   f3(d)  105.79767  108.02553  114.48950  110.17056  112.85710  149.42474    10  b  
##  f3c(d)   14.41182   14.43227   14.70098   14.49289   14.84504   15.67709    10 a   

R.version$version.string
## [1] "R Under development (unstable) (2015-03-24 r68072)"

如此优秀的R编程实践和性能分析工具的使用可以带您走很长的路。如果您想进一步改进,可以转到Rcpp,但这可能足以达到您的目的。

答案 1 :(得分:2)

我认为使用Rcppdata.table可以更好地实现这一目标。对于这个问题,你真的不需要在R中进行for循环。

我建议采用以下方法?

创建一个新的source.cpp文件,如下所示(示例目录为C:\ Projects)

#include <Rcpp.h>
using namespace Rcpp;
// [[Rcpp::export]]
List myFunction(NumericVector x,NumericVector y) {
  const int n(x.size());
  NumericVector res(n);
  // x is age_in_days
  // y is DEMAAdmNo
  for (int i=1; i<n; i++)  {
       res[i]=0;
       for (int j=1; j<j; j++) {
            if ( (x[i]>x[j]) & ((x[i]-x[j])<=183) & (y[j]==1)) {
            res[i]=res[i]+1;
            }
       }
  }
  return Rcpp::List::create(_["res"] = res);
}

如果您没有安装Rcpp软件包,请执行此操作并加载上面创建的cpp文件,如下所示:

Rcpp::sourceCpp('C:/Projects/source.cpp')

然后,在您的主文件中,执行以下操作:

library(data.table) #If not installed, do install.packages('data.table')
sh_data=fread('C:/Projects/data3.csv') #Please put your correct file path here
sh_data[, LACE_E := myFunction(Age_in_days, DEMAdmNo), by=Patient_ID]

我无法验证数字,因为您还没有指定所需的输出,因此请调整if文件中的cpp语句。

在任何情况下,Rcppdata.table的组合将为您节省大量时间。强烈推荐。

希望这有帮助。