表格的Rcpp版本较慢;这是从哪里来的,如何理解

时间:2015-06-23 11:20:44

标签: r rcpp

在为已经聚合的数据创建一些采样函数的过程中,我发现表格在我使用的大小数据上相当慢。我尝试了两项改进,首先是Rcpp函数,如下所示

// [[Rcpp::export]]
IntegerVector getcts(NumericVector x, int m) {
  IntegerVector cts(m);
  int t;
  for (int i = 0; i < x.length(); i++) {
    t = x[i] - 1;
    if (0 <= t && t < m)
      cts[t]++;
  }
  return cts;
}

然后在试图理解为什么表格相当慢的时候,我发现它基于制表。 Tabulate对我来说效果很好,并且比Rcpp版本更快。制表的代码位于:

https://github.com/wch/r-source/blob/545d365bd0485e5f0913a7d609c2c21d1f43145a/src/main/util.c#L2204

关键是:

for(R_xlen_t i = 0 ; i < n ; i++)
  if (x[i] != NA_INTEGER && x[i] > 0 && x[i] <= nb) y[x[i] - 1]++;

现在制表和我的Rcpp版本的关键部分看起来非常接近(我没有打扰过NA)。

Q1:为什么我的Rcpp版本慢了3倍?

Q2:我怎样才能知道这个时间到了哪里?

我非常感谢知道时间的去向,但更好的方法是分析代码。我的C ++技能只是如此,但这似乎很简单,我应该(交叉我的手指)能够避免任何会使我的时间增加三倍的愚蠢的东西。

我的时间码:

max_x <- 100
xs <- sample(seq(max_x), size = 50000000, replace = TRUE)
system.time(getcts(xs, max_x))
system.time(tabulate(xs))

这为得分提供0.318,为表格提供0.126。

1 个答案:

答案 0 :(得分:5)

您的函数在每次循环迭代中调用length方法。似乎编译器没有缓存它。要在单独的变量中修复此向量的存储大小,或使用基于范围的循环。另请注意,我们并不需要明确的缺失值检查,因为在C ++中,涉及NaN的所有比较始终返回false

让我们比较一下表现:

#include <Rcpp.h>
using namespace Rcpp;

// [[Rcpp::export]]
IntegerVector tabulate1(const IntegerVector& x, const unsigned max) {
    IntegerVector counts(max);
    for (std::size_t i = 0; i < x.size(); i++) {
        if (x[i] > 0 && x[i] <= max)
            counts[x[i] - 1]++;
    }
    return counts;
}

// [[Rcpp::export]]
IntegerVector tabulate2(const IntegerVector& x, const unsigned max) {
    IntegerVector counts(max);
    std::size_t n = x.size();
    for (std::size_t i = 0; i < n; i++) {
        if (x[i] > 0 && x[i] <= max)
            counts[x[i] - 1]++;
    }
    return counts;
}

// [[Rcpp::plugins(cpp11)]]
// [[Rcpp::export]]
IntegerVector tabulate3(const IntegerVector& x, const unsigned max) {
    IntegerVector counts(max);
    for (auto& now : x) {
        if (now > 0 && now <= max)
            counts[now - 1]++;
    }
    return counts;
}

// [[Rcpp::plugins(cpp11)]]
// [[Rcpp::export]]
IntegerVector tabulate4(const IntegerVector& x, const unsigned max) {
    IntegerVector counts(max);
    for (auto it = x.begin(); it != x.end(); it++) {
        if (*it > 0 && *it <= max)
            counts[*it - 1]++;
    }
    return counts;
}

/***R
library(microbenchmark)
x <- sample(10, 1e5, rep = TRUE)
microbenchmark(
    tabulate(x, 10), tabulate1(x, 10),
    tabulate2(x, 10), tabulate3(x, 10), tabulate4(x, 10)
)
x[sample(10e5, 10e3)] <- NA
microbenchmark(
    tabulate(x, 10), tabulate1(x, 10),
    tabulate2(x, 10), tabulate3(x, 10), tabulate4(x, 10)
)
*/

tabulate1是原始版本。

基准测试结果:

没有NA

Unit: microseconds
            expr     min       lq     mean   median      uq     max neval
 tabulate(x, 10) 143.557 146.8355 169.2820 156.1970 177.327 286.370   100
tabulate1(x, 10) 390.706 392.6045 437.7357 416.5655 443.065 748.767   100
tabulate2(x, 10) 108.149 111.4345 139.7579 118.2735 153.118 337.647   100
tabulate3(x, 10) 107.879 111.7305 138.2711 118.8650 139.598 300.023   100
tabulate4(x, 10) 391.003 393.4530 436.3063 420.1915 444.048 777.862   100

使用NA

Unit: microseconds
            expr      min        lq     mean   median       uq      max neval
 tabulate(x, 10)  943.555 1089.5200 1614.804 1333.806 2042.320 3986.836   100
tabulate1(x, 10) 4523.076 4787.3745 5258.490 4929.586 5624.098 7233.029   100
tabulate2(x, 10)  765.102  931.9935 1361.747 1113.550 1679.024 3436.356   100
tabulate3(x, 10)  773.358  914.4980 1350.164 1140.018 1642.354 3633.429   100
tabulate4(x, 10) 4241.025 4466.8735 4933.672 4717.016 5148.842 8603.838   100

使用迭代器的tabulate4函数也慢于tabulate。我们可以改进它:

// [[Rcpp::plugins(cpp11)]]
// [[Rcpp::export]]
IntegerVector tabulate4(const IntegerVector& x, const unsigned max) {
    IntegerVector counts(max);
    auto start = x.begin();
    auto end = x.end();
    for (auto it = start; it != end; it++) {
        if (*(it) > 0 && *(it) <= max)
            counts[*(it) - 1]++;
    }
    return counts;
}