我试图找到一种有效的方法来计算X ^ T * W * X,其中X是一个密集的mat
大小,例如10,000 x 10和W是对角矩阵(我只存储vec
中的对角线。)
目前,我使用此功能
arma::mat& getXtW(const arma::mat& covar,
const arma::vec& w,
arma::mat& tcovar,
size_t n, size_t K) {
size_t i, k;
for (i = 0; i < n; i++) {
for (k = 0; k < K; k++) {
tcovar(k, i) = covar(i, k) * w(i);
}
}
return tcovar;
}
并计算
tcovar = getXtW(covar, w, tcovar, n, K);
cprod = tcovar * covar;
然而,这似乎不是最佳的。
PS:你可以看到整个代码there。
编辑1:似乎我可以使用covar.t() * (covar.each_col() % w)
,但这似乎要快得多。
Edit2:如果我自己在Rcpp中使用循环实现它:
arma::mat testProdW2(const arma::mat& x, const arma::vec& w) {
int n = x.n_rows;
int K = x.n_cols;
arma::mat res(K, K);
double tmp;
for (int k = 0; k < K; k++) {
for (int j = k; j < K; j++) {
tmp = 0;
for (int i = 0; i < n; i++) {
tmp += x(i, j) * w[i] * x(i, k);
}
res(j, k) = tmp;
}
}
for (int k = 0; k < K; k++) {
for (int j = 0; j < k; j++) {
res(j, k) = res(k, j);
}
}
return res;
}
这比第一次实施慢。
答案 0 :(得分:2)
根据BLAS matrix by matrix transpose multiply,没有BLAS例程可以直接执行此操作。相反,建议在X
的行上循环并使用dsyr
。我发现这是一个有趣的问题,因为我知道如何link BLAS in Rcpp,但尚未使用RcppArmadillo做到这一点。 Stack Overflow也知道答案:Rcpparmadillo: can't call Fortran routine "dgebal"?。注意:我尚未检查,但我希望dsyr
不属于R随附的BLAS子集。因此,仅当您的R链接到完整的BLAS实现时,此方法才有效。
与此结合,我们得到:
// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadillo.h>
#include <Rcpp/Benchmark/Timer.h>
#ifdef ARMA_USE_LAPACK
#if !defined(ARMA_BLAS_CAPITALS)
#define arma_dsyr dsyr
#else
#define arma_dsyr DSYR
#endif
extern "C"
void arma_fortran(arma_dsyr)(const char* uplo, const int* n, const double* alpha, const double* X, const int* incX, const double* A, const int* ldA);
#endif
// [[Rcpp::export]]
Rcpp::NumericVector getXtWX(const arma::mat& X, const arma::vec& w) {
Rcpp::Timer timer;
timer.step("start");
arma::mat result1 = X.t() * (X.each_col() % w);
timer.step("Armadillo result");
const int n = X.n_rows;
const int k = X.n_cols;
arma::mat result(k, k, arma::fill::zeros);
for (size_t i = 0; i < n; ++i) {
F77_CALL(dsyr)("U", &k, &w(i), &X(i,0), &n, result.memptr(), &k);
}
result = arma::symmatu(result);
timer.step("BLAS result");
Rcpp::NumericVector res(timer);
return res;
}
/*** R
n <- 10000
k <- 10
X <- matrix(runif(n*k), n, k)
w <- runif(n)
Reduce(rbind, lapply(1:6, function(x) diff(getXtWX(X, w))/1e6))
*/
但是,对我而言,BLAS解决方案要慢得多:
> Reduce(rbind, lapply(1:6, function(x) diff(getXtWX(X, w))/1e6))
Armadillo result BLAS result
init 1.291243 6.666026
1.176143 6.623282
1.102111 6.644165
1.094917 6.612596
1.098619 6.588431
1.069286 6.615529
我试图通过首先转置矩阵来改善这一点,希望在遍历列矩阵时可以更快地访问内存,但这对我(低功耗)系统没有影响。