Armadillo中新的`find_finite`函数比循环慢3.5倍?

时间:2014-05-07 12:41:49

标签: r performance rcpp armadillo

Armadillo 4.300中新的find_finitefind_nonfinite功能是很棒的补充!在我使用Rcpp的测试中,它们比标准循环慢约2.5倍。下面是一些代码,用于计算与R na.rm=TRUE选项对应的案例删除的和和均值。 R的性能基准测试表明,与循环相比,第一个版本(sum_armamean_arma)快3.5倍。我做的一切都正确吗?有什么方法可以提高性能吗?

C ++代码

#include <numeric>
#include <RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]


// [[Rcpp::export]]
double sum_arma1(arma::mat& X) {
    double sum = 0;
    for (int i = 0; i < X.size(); ++i) {
        if (arma::is_finite(X(i)))
            sum += X(i);
    }
    return sum;
}
// [[Rcpp::export]]
double sum_arma2(arma::mat& X) {
    return arma::sum(X.elem(arma::find_finite(X)));
}

// [[Rcpp::export]]
double mean_arma1(arma::mat& X) {
    double sum = 0;
    int n = 0;
    for (int i = 0; i < X.size(); ++i) {
        if (arma::is_finite(X(i))) {
            sum += X(i);
            n += 1;
        }
    }
    return sum/n;
}
// [[Rcpp::export]]
double mean_arma2(arma::mat& X) {
    return arma::mean(X.elem(arma::find_finite(X)));
}

来自R的基准测试结果

# data
X = matrix(rnorm(1e6),1000,1000)
X[sample(1:1000,100),sample(1:1000,100)] = NA
# equal?
all.equal(sum(X, na.rm=TRUE),sum_arma1(X))
all.equal(sum(X, na.rm=TRUE),sum_arma2(X))
all.equal(mean(X, na.rm=TRUE),mean_arma1(X))
all.equal(mean(X, na.rm=TRUE),mean_arma2(X))

# benchmark
benchmark(
    sum(X, na.rm=TRUE),
    sum_arma1(X),
    sum_arma2(X),
    replications=100)

#                   test replications elapsed relative user.self sys.self
# 2         sum_arma1(X)          100   0.259    1.000     0.259    0.001
# 3         sum_arma2(X)          100   1.035    3.996     0.750    0.293
# 1 sum(X, na.rm = TRUE)          100   0.491    1.896     0.492    0.003

benchmark(
    mean(X, na.rm=TRUE),
    mean_arma1(X),
    mean_arma2(X),
    replications=100)

#                   test replications elapsed relative user.self sys.self
# 2         mean_arma1(X)          100   0.252     1.00     0.253    0.001
# 3         mean_arma2(X)          100   0.819     3.25     0.620    0.206
# 1 mean(X, na.rm = TRUE)          100   7.440    29.52     7.120    0.373

1 个答案:

答案 0 :(得分:2)

一般函数find_finite()find_nonfinite()总是比专门的求和循环慢。 find_finite()并非专门用于求和,而是针对一般情况,即找到有限值的指数。您对这些索引的处理取决于您,并且您已选择将它们用作.elem()函数的输入。

在代码arma::sum(X.elem(arma::find_finite(X)))中,函数find_finite()必须通过X,寻找有限值,并将有限值的结果索引存储在临时向量中。然后.elem()成员函数查看由find_finite()生成的向量,并创建另一个仅包含有限值的向量。反过来,.elem()生成的向量随后由sum()使用。

C ++允许抽象,因此您的代码非常紧凑,但有时您必须为此类抽象付费。一般函数总是比专用循环慢。

然而,对于诸如加法,乘法等算术函数,Armadillo将尝试通过使用智能延迟操作框架(基于模板表达式)来避免生成临时向量/矩阵,该框架排队并组合几个执行前的操作。这减少了临时工的产生。

延迟操作的实现非常复杂,这就是为什么它主要用于最重要的算术功能。但是,Armadillo也会在其他一些情况下使用它,例如,find(X > 123)将避免为X > 123生成临时值。