将列向量乘以RcppArmadillo中的数字标量

时间:2013-08-21 15:28:05

标签: r matrix rcpp armadillo

使用c++Rcpp包编译这个简单的RcppArmadillo代码时遇到了一些问题。使用以下简单示例将矩阵的每列乘以数字标量:

code <- 'arma::mat out = Rcpp::as<arma::mat>(m);
for(int i = 0; i < out.n_cols; ++i){
  out.col(i) *= v;
}
return Rcpp::wrap( out );'

尝试使用...

进行编译
require( RcppArmadillo )
armMult <- cxxfunction( signature( m = "numeric" , v = "numeric" ),
                        code , plugin = "RcppArmadillo" )

导致编译错误....

#error: no match for 'operator*=' in 'arma::Mat<eT>::col(arma::uword) [with eT = double, arma::uword = unsigned int](((unsigned int)i)) *= v'

但是,如果我们将numeric变量v替换为2.0,如下所示....

code <- 'arma::mat out = Rcpp::as<arma::mat>(m);
for(int i = 0; i < out.n_cols; ++i){
  out.col(i) *= 2.0; //Notice we use 2.0 instead of a variable
}
return Rcpp::wrap( out );'

它编译得很好......

armMult <- cxxfunction( signature(m="numeric"),
                        code,plugin="RcppArmadillo")

然后我们可以做......

m <- matrix( 1:4 , 2 , 2 )

armMult( m )
     [,1] [,2]
[1,]    2    6
[2,]    4    8

我在这里缺少什么?如何使用简单的数字标量来完成此工作。我希望能够通过像......这样的标量。

armMult( m , 2.0 )

并返回与上面相同的结果。

3 个答案:

答案 0 :(得分:8)

如果您想将矩阵 A 的每列乘以向量 x 的对应元素,请尝试以下操作:

Rcpp:::cppFunction(
    "arma::mat fun(arma::mat A, arma::rowvec x) 
    { 
        A.each_row() %= x;
        return A;
    }", depends = "RcppArmadillo"
)

fun(matrix(rep(1, 6), 3, 2), c(5, 1))

     [,1] [,2]
[1,]    5    1
[2,]    5    1
[3,]    5    1

答案 1 :(得分:2)

每当我抓住这样的问题时,我首先要减少问题。尝试使用Armadillo标头的C ++三线程。让它工作,然后把它移到RcppArmadillo。

编辑:一个人可以做得比你的答案更好,因为你不需要单独地对每一列进行乘法(尽管可以)。无论如何,这只是展示 Rcpp属性

> cppFunction("arma::mat simon(arma::mat m, double v) { return m * v;}", 
+             depends="RcppArmadillo")
> simon(matrix(1:4,2,2), 3)
     [,1] [,2]
[1,]    3    9
[2,]    6   12
> 

答案 2 :(得分:1)

感谢@DirkEddelbuettel的评论,这只是因为我没有定义v ......

code <- '
arma::mat out = Rcpp::as<arma::mat>(m);
double scl = Rcpp::as<double>(v);
for(int i = 0; i < out.n_cols; ++i){
  out.col(i) *= scl;
}
return Rcpp::wrap( out );
'

armMult <- cxxfunction( signature( m = "numeric" , v = "numeric" ),
                        code , plugin = "RcppArmadillo" )

armMult( m , 2.0 )
     [,1] [,2]
[1,]    2    6
[2,]    4    8