RcppArmadillo中稀疏和密集矩阵的模板化函数

时间:2014-03-19 17:23:06

标签: c++ r rcpp armadillo

我试图定义一个可以使用RcppArmadillo处理稀疏和密集矩阵输入的模板化函数。我得到了一个非常简单的例子,即向C ++发送密集或稀疏矩阵,然后回到R,就像这样工作:

library(inline); library(Rcpp); library(RcppArmadillo)

sourceCpp(code =    "
// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadillo.h>
using namespace Rcpp ;
using namespace arma ;

template <typename T> T importexport_template(const T X) {
    T ret = X ;
    return ret ;
};

//[[Rcpp::export]]
SEXP importexport(SEXP X) {
    return wrap( importexport_template(X) ) ;
}")

library(Matrix)
X <- diag(3)
X_sp <- as(X, "dgCMatrix")

importexport(X)
##     [,1] [,2] [,3]
##[1,]    1    0    0
##[2,]    0    1    0
##[3,]    0    0    1
importexport(X_sp)
##3 x 3 sparse Matrix of class "dgCMatrix"
##          
##[1,] 1 . .
##[2,] . 1 .
##[3,] . . 1

我认为这意味着模板基本上有效(即,密集的R矩阵变成arma::mat,而稀疏的R矩阵变成arma::sp_mat - 对象对Rcpp::as的隐式调用,以及相应的impliict Rcpp:wrap然后做正确的事情,并为密集和稀疏返回密集稀疏)。

我尝试编写的实际函数当然需要多个参数,以及我失败的地方 - 做类似的事情:

sourceCpp(code ="
// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadillo.h>

using namespace Rcpp ;
using namespace arma ;

template <typename T> T scalarmult_template(const T X, double scale) {
    T ret = X * scale;
    return ret;
};

//[[Rcpp::export]]
SEXP scalarmult(SEXP X, double scale) {
    return wrap(scalarmult_template(X, scale) ) ;
}")

失败,因为编译器不知道如何在*的编译时解析SEXPREC* const。 所以我想我需要像switch语句in this Rcpp Gallery snippet那样正确地发送到特定的模板函数,但我不知道如何为那些看似比INTSXP等更复杂的类型编写代码。

我想我知道如何访问我需要这种switch语句的类型,例如:

sourceCpp(code ="
// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadillo.h>

using namespace Rcpp ;
using namespace arma ;

//[[Rcpp::export]]
SEXP printtype(SEXP Xr) {
    Rcpp::Rcout << TYPEOF(Xr) << std::endl ;
    return R_NilValue;
}")
printtype(X)
##14
##NULL
printtype(X_sp)
##25
##NULL

但我不明白如何从那里开始。适用于稀疏和密集矩阵的scalarmult_template版本会是什么样的?

1 个答案:

答案 0 :(得分:4)

根据@ KevinUshey的评论回答我自己的问题。我做3个矩阵乘法:密集密集,稀疏密集,&#34; indMatrix&#34; -dense:

library(inline)
library(Rcpp)
library(RcppArmadillo)
library(Matrix)
library(rbenchmark)

sourceCpp(code="
// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadillo.h>

using namespace Rcpp ;
using namespace arma ;

arma::mat matmult_sp(const arma::sp_mat X, const arma::mat Y){
    arma::mat ret = X * Y;
    return ret;
};
arma::mat matmult_dense(const arma::mat X, const arma::mat Y){
    arma::mat ret = X * Y;
    return ret;
};
arma::mat matmult_ind(const SEXP Xr, const arma::mat Y){
    // pre-multplication with index matrix is a permutation of Y's rows: 
    S4 X(Xr);
    arma::uvec perm =  X.slot("perm");
    arma::mat ret = Y.rows(perm - 1);
    return ret;
};

//[[Rcpp::export]]
arma::mat matmult_cpp(SEXP Xr, const arma::mat Y) {
    if (Rf_isS4(Xr)) {
        if(Rf_inherits(Xr, "dgCMatrix")) {
            return matmult_sp(as<arma::sp_mat>(Xr), Y) ;
        } ;
        if(Rf_inherits(Xr, "indMatrix")) {
            return matmult_ind(Xr, Y) ; 
        } ;
        stop("unknown class of Xr") ;
    } else {
        return matmult_dense(as<arma::mat>(Xr), Y) ;
    } 
}")

n <- 10000
d <- 20
p <- 30  

X <- matrix(rnorm(n*d), n, d)
X_sp <- as(diag(n)[,1:d], "dgCMatrix")
X_ind <- as(sample(1:d, n, rep=TRUE), "indMatrix")
Y <- matrix(1:(d*p), d, p)

matmult_cpp(as(X_ind, "ngTMatrix"), Y)
## Error: unknown class of Xr

all.equal(X%*%Y, matmult_cpp(X, Y))
## [1] TRUE

all.equal(as.vector(X_sp%*%Y), 
          as.vector(matmult_cpp(X_sp, Y)))
## [1] TRUE

all.equal(X_ind%*%Y, matmult_cpp(X_ind, Y))
## [1] TRUE

编辑:这已经变为Rcpp Gallery post