犰狳的高效加权交叉产品

时间:2018-01-23 21:47:00

标签: performance rcpp armadillo

我试图找到一种有效的方法来计算X ^ T * W * X,其中X是一个密集的mat大小,例如10,000 x 10和W是对角矩阵(我只存储vec中的对角线。)

目前,我使用此功能

arma::mat& getXtW(const arma::mat& covar,
                  const arma::vec& w,
                  arma::mat& tcovar,
                  size_t n, size_t K) {
  size_t i, k;

  for (i = 0; i < n; i++) {
    for (k = 0; k < K; k++) {
      tcovar(k, i) = covar(i, k) * w(i);
    }
  }

  return tcovar;
}

并计算

tcovar = getXtW(covar, w, tcovar, n, K);
cprod = tcovar * covar;

然而,这似乎不是最佳的。

PS:你可以看到整个代码there

编辑1:似乎我可以使用covar.t() * (covar.each_col() % w),但这似乎要快得多。

Edit2:如果我自己在Rcpp中使用循环实现它:

arma::mat testProdW2(const arma::mat& x, const arma::vec& w) {

  int n = x.n_rows;
  int K = x.n_cols;
  arma::mat res(K, K);
  double tmp;
  for (int k = 0; k < K; k++) {
    for (int j = k; j < K; j++) {
      tmp = 0;
      for (int i = 0; i < n; i++) {
        tmp += x(i, j) * w[i] * x(i, k);
      }
      res(j, k) = tmp;
    }
  }

  for (int k = 0; k < K; k++) {
    for (int j = 0; j < k; j++) {
      res(j, k) = res(k, j);
    }
  }

  return res;
}

这比第一次实施慢。

1 个答案:

答案 0 :(得分:2)

根据BLAS matrix by matrix transpose multiply,没有BLAS例程可以直接执行此操作。相反,建议在X的行上循环并使用dsyr。我发现这是一个有趣的问题,因为我知道如何link BLAS in Rcpp,但尚未使用RcppArmadillo做到这一点。 Stack Overflow也知道答案:Rcpparmadillo: can't call Fortran routine "dgebal"?。注意:我尚未检查,但我希望dsyr不属于R随附的BLAS子集。因此,仅当您的R链接到完整的BLAS实现时,此方法才有效。

与此结合,我们得到:

// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadillo.h>
#include <Rcpp/Benchmark/Timer.h>

#ifdef ARMA_USE_LAPACK
#if !defined(ARMA_BLAS_CAPITALS)
 #define arma_dsyr dsyr
#else
 #define arma_dsyr DSYR
#endif

extern "C"
void arma_fortran(arma_dsyr)(const char* uplo, const int* n, const double* alpha, const double* X, const int* incX, const double* A, const int* ldA);

#endif


// [[Rcpp::export]]
Rcpp::NumericVector getXtWX(const arma::mat& X, const arma::vec& w) {
  Rcpp::Timer timer;
  timer.step("start"); 

  arma::mat result1 = X.t() * (X.each_col() % w);
  timer.step("Armadillo result");

  const int n = X.n_rows;
  const int k = X.n_cols;
  arma::mat result(k, k, arma::fill::zeros);  
  for (size_t i = 0; i < n; ++i) {
    F77_CALL(dsyr)("U", &k, &w(i), &X(i,0), &n, result.memptr(), &k);
  }
  result = arma::symmatu(result);
  timer.step("BLAS result");
  Rcpp::NumericVector res(timer);
  return res;
}

/*** R
n <- 10000
k <- 10
X <- matrix(runif(n*k), n, k)
w <- runif(n)
Reduce(rbind, lapply(1:6, function(x) diff(getXtWX(X, w))/1e6))
*/

但是,对我而言,BLAS解决方案要慢得多:

> Reduce(rbind, lapply(1:6, function(x) diff(getXtWX(X, w))/1e6))
     Armadillo result BLAS result
init         1.291243    6.666026
             1.176143    6.623282
             1.102111    6.644165
             1.094917    6.612596
             1.098619    6.588431
             1.069286    6.615529

我试图通过首先转置矩阵来改善这一点,希望在遍历列矩阵时可以更快地访问内存,但这对我(低功耗)系统没有影响。