在Rcpp中使用NumericMatrix和NumericVector的矩阵乘法

时间:2015-02-11 22:27:26

标签: matrix-multiplication rcpp

我想知道有没有一种使用NumericMatrix和NumericVector类计算矩阵乘法的方法。我想知道是否有任何简单的方法 帮助我避免以下循环进行此计算。我只是想计算X%*%beta。

// assume X and beta are initialized and X is of dimension (nsites, p), 
// beta is a NumericVector with p elements. 
for(int j = 0; j < nsites; j++)
 {
    temp = 0;

    for(int l = 0; l < p; l++) temp = temp + X(j,l) * beta[l];

}

非常感谢您提前!

1 个答案:

答案 0 :(得分:5)

根据Dirk的评论,以下是一些通过重载*运算符演示Armadillo库的矩阵乘法的案例:

#include <RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]

// [[Rcpp::export(".mm")]]
arma::mat mm_mult(const arma::mat& lhs,
                  const arma::mat& rhs)
{
  return lhs * rhs;
}

// [[Rcpp::export(".vm")]]
arma::mat vm_mult(const arma::vec& lhs,
                  const arma::mat& rhs)
{
  return lhs.t() * rhs;
}

// [[Rcpp::export(".mv")]]
arma::mat mv_mult(const arma::mat& lhs,
                  const arma::vec& rhs)
{
  return lhs * rhs;
}

// [[Rcpp::export(".vv")]]
arma::mat vv_mult(const arma::vec& lhs,
                  const arma::vec& rhs)
{
  return lhs.t() * rhs;
}

然后,您可以定义一个R函数来调度相应的C ++函数:

`%a*%` <- function(x,y) {

  if (is.matrix(x) && is.matrix(y)) {
    return(.mm(x,y))
  } else if (!is.matrix(x) && is.matrix(y)) {
    return(.vm(x,y))
  } else if (is.matrix(x) && !is.matrix(y)) {
    return(.mv(x,y))
  } else {
    return(.vv(x,y))
  }

}
##
mx <- matrix(1,nrow=3,ncol=3)
vx <- rep(1,3)
my <- matrix(.5,nrow=3,ncol=3)
vy <- rep(.5,3)

与R&#39 {s} %*%功能相比:

R>  mx %a*% my
     [,1] [,2] [,3]
[1,]  1.5  1.5  1.5
[2,]  1.5  1.5  1.5
[3,]  1.5  1.5  1.5

R>  mx %*% my
     [,1] [,2] [,3]
[1,]  1.5  1.5  1.5
[2,]  1.5  1.5  1.5
[3,]  1.5  1.5  1.5
##
R>  vx %a*% my
     [,1] [,2] [,3]
[1,]  1.5  1.5  1.5

R>  vx %*% my
     [,1] [,2] [,3]
[1,]  1.5  1.5  1.5
##
R>  mx %a*% vy
     [,1]
[1,]  1.5
[2,]  1.5
[3,]  1.5

R>  mx %*% vy
     [,1]
[1,]  1.5
[2,]  1.5
[3,]  1.5
##
R>  vx %a*% vy
     [,1]
[1,]  1.5

R>  vx %*% vy
     [,1]
[1,]  1.5