如何在R中编写简单矩阵乘法的Rcpp函数

时间:2015-07-28 08:25:59

标签: r rcpp

我已经编写了一个Rcpp代码来计算R中的元素明确矩阵乘法。但是当尝试运行此代码时R停止工作并退出。如何纠正这个功能?

提前致谢。

library(Rcpp)
func <- 'NumericMatrix mmult( NumericMatrix m , NumericMatrix v, bool  byrow=true )
{
if( ! m.nrow() == v.nrow() ) stop("Non-conformable arrays") ;
if( ! m.ncol() == v.ncol() ) stop("Non-conformable arrays") ;

NumericMatrix out(m) ;

for (int i = 1; i <= m.nrow(); i++) 
{
  for (int j = 1; j <= m.ncol(); j++) 
    {
      out(i,j)=m(i,j) * v(i,j) ;
    }
}
return out ;
}'

cppFunction( func )

m1<-matrix(1:4,2,2)
m2<-m1
r1<-mmult(m1,m2)
r2<-m1*m2

2 个答案:

答案 0 :(得分:4)

您必须记住使用0个索引数组。 (请参阅Why does the indexing start with zero in 'C'?Why are zero-based arrays the norm?。)

因此,您需要将循环定义为从0m.nrow() - 1

试试这个:

func <- '
NumericMatrix mmult( NumericMatrix m , NumericMatrix v, bool  byrow=true )
{
  if( ! m.nrow() == v.nrow() ) stop("Non-conformable arrays") ;
  if( ! m.ncol() == v.ncol() ) stop("Non-conformable arrays") ;

  NumericMatrix out(m) ;

  for (int i = 0; i < m.nrow(); i++) 
  {
  for (int j = 0; j < m.ncol(); j++) 
  {
  out(i,j)=m(i,j) * v(i,j) ;
  }
  }
  return out ;
  }
'

然后我得到:

> mmult(m1,m2)
     [,1] [,2]
[1,]    1    9
[2,]    4   16

> m1*m2
     [,1] [,2]
[1,]    1    9
[2,]    4   16

答案 1 :(得分:4)

(至少对我来说)明显的选择是使用RcppArmadillo:

R> cppFunction("arma::mat matmult(arma::mat A, arma::mat B) { return A % B; }", 
+              depends="RcppArmadillo")
R> m1 <- m2 <- matrix(1:4,2,2)
R> matmult(m1,m2)
     [,1] [,2]
[1,]    1    9
[2,]    4   16
R> 

因为犰狳是强类型的,并且有一个逐个元素的乘法算子(%),我们在它所使用的单行中使用。