rcpp矩阵乘法速度提高

时间:2015-04-16 20:30:31

标签: performance matrix eigen rcpp armadillo

我正在使用Rcpp进行一些繁重的计算。但是,我的代码运行速度不如预期。我知道我一定做错了,因为在MATLAB中相同的计算速度更快。我很难弄清楚如何改进我的代码,所以我真的在这里寻求帮助。

我有一个名为mX的12乘5000 ^ 2/2矩阵,以及一个名为w的5000 ^ 2/2乘1矩阵。 我想计算A = mX * diag(w)* t(mX)。

我写了以下四个函数来做。

#include <RcppArmadillo.h>
#include <RcppEigen.h>
#include <omp.h>

// [[Rcpp::depends("RcppArmadillo")]]
// [[Rcpp::depends("RcppEigen")]]
// [[Rcpp::plugins(openmp)]]
using namespace Rcpp; 
using Eigen::MatrixXd;


// [[Rcpp::export]]
arma::mat MultiplyArma1(arma::mat mX, arma::colvec w){
  return mX*arma::diagmat(w)*arma::trans(mX);
}

// [[Rcpp::export]]
arma::mat MultiplyArma2(arma::mat mX, arma::colvec w){
  return mX*(arma::repmat(w,1,12)%arma::trans(mX));
}

// [[Rcpp::export]]
arma::mat MultiplyCpp(arma::mat mX, arma::colvec w){ 
    omp_set_num_threads(4);
    int N=mX.n_rows;
    int K=mX.n_cols;
    arma::mat tmX=arma::zeros<arma::mat>(K,N);
    arma::mat ProductF=arma::zeros<arma::mat>(N,N);
    arma::mat ProductM=arma::zeros<arma::mat>(K,N);

    #pragma omp parallel for schedule(static)
    for (int i=0; i<K; ++i){
      for (int j=0; j<N; ++j){     
            tmX(i,j)=mX(j,i);
        }
    }


    #pragma omp parallel for schedule(static)
    for (int j=0; j<N; ++j){
        for (int i=0; i<K; ++i){
            ProductM(i,j)=tmX(i,j)*w(i);
        }
    }

    #pragma omp parallel for schedule(static)
    for (int j=0; j<N; ++j){
        for (int i=0; i<N; ++i){
          for (int k=0; k<K; ++k){
            ProductF(i,j) +=tmX(k,i)*ProductM(k,i);
          }
        }
    }
    return ProductF;
}

// [[Rcpp::export]]
MatrixXd MultiplyEigen(const MatrixXd mX, const MatrixXd Dw){
  return mX* ( Dw.cwiseProduct( mX.transpose() ) );   
}

第1和第2都使用RcppArmadillo,并且只有diag(w)的表示不同。第3个使用openmp,我明智地编码everthing元素。第4次使用RcppEigen。我不知道如何在RcppEigen中生成diag(w),所以输入为Dw = rep(w,1,12)。

mX=matrix(0.7, 12, 5000^2/2)
w=matrix(0.2, 5000^2/2,1)
Dw=repmat(w, 1,12)

tic()
a1=MultiplyArma1(mX,w)
toc()

tic()
a2=MultiplyArma2(mX,w)
toc()

tic()
a3=MultiplyCpp(mX,w)
toc()

tic()
a4=MultiplyEigen(mX, Dw)
toc()

然后,当我计算速度时,它们都需要超过3秒钟。但是,当我在MATLAB A=mX*(repmat(w,1,12).*mX')中进行相同的计算时,只需1.5秒。

这让我觉得仍然有改进我的Rcpp功能的空间,但老实说我不知道​​如何。我非常感谢你的帮助!

0 个答案:

没有答案