我试图定义一个可以使用RcppArmadillo
处理稀疏和密集矩阵输入的模板化函数。我得到了一个非常简单的例子,即向C ++发送密集或稀疏矩阵,然后回到R,就像这样工作:
library(inline); library(Rcpp); library(RcppArmadillo)
sourceCpp(code = "
// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadillo.h>
using namespace Rcpp ;
using namespace arma ;
template <typename T> T importexport_template(const T X) {
T ret = X ;
return ret ;
};
//[[Rcpp::export]]
SEXP importexport(SEXP X) {
return wrap( importexport_template(X) ) ;
}")
library(Matrix)
X <- diag(3)
X_sp <- as(X, "dgCMatrix")
importexport(X)
## [,1] [,2] [,3]
##[1,] 1 0 0
##[2,] 0 1 0
##[3,] 0 0 1
importexport(X_sp)
##3 x 3 sparse Matrix of class "dgCMatrix"
##
##[1,] 1 . .
##[2,] . 1 .
##[3,] . . 1
我认为这意味着模板基本上有效(即,密集的R矩阵变成arma::mat
,而稀疏的R矩阵变成arma::sp_mat
- 对象对Rcpp::as
的隐式调用,以及相应的impliict Rcpp:wrap
然后做正确的事情,并为密集和稀疏返回密集稀疏)。
我尝试编写的实际函数当然需要多个参数,以及我失败的地方 - 做类似的事情:
sourceCpp(code ="
// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadillo.h>
using namespace Rcpp ;
using namespace arma ;
template <typename T> T scalarmult_template(const T X, double scale) {
T ret = X * scale;
return ret;
};
//[[Rcpp::export]]
SEXP scalarmult(SEXP X, double scale) {
return wrap(scalarmult_template(X, scale) ) ;
}")
失败,因为编译器不知道如何在*
的编译时解析SEXPREC* const
。
所以我想我需要像switch语句in this Rcpp Gallery snippet那样正确地发送到特定的模板函数,但我不知道如何为那些看似比INTSXP
等更复杂的类型编写代码。
我想我知道如何访问我需要这种switch语句的类型,例如:
sourceCpp(code ="
// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadillo.h>
using namespace Rcpp ;
using namespace arma ;
//[[Rcpp::export]]
SEXP printtype(SEXP Xr) {
Rcpp::Rcout << TYPEOF(Xr) << std::endl ;
return R_NilValue;
}")
printtype(X)
##14
##NULL
printtype(X_sp)
##25
##NULL
但我不明白如何从那里开始。适用于稀疏和密集矩阵的scalarmult_template
版本会是什么样的?
答案 0 :(得分:4)
根据@ KevinUshey的评论回答我自己的问题。我做3个矩阵乘法:密集密集,稀疏密集,&#34; indMatrix&#34; -dense:
library(inline)
library(Rcpp)
library(RcppArmadillo)
library(Matrix)
library(rbenchmark)
sourceCpp(code="
// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadillo.h>
using namespace Rcpp ;
using namespace arma ;
arma::mat matmult_sp(const arma::sp_mat X, const arma::mat Y){
arma::mat ret = X * Y;
return ret;
};
arma::mat matmult_dense(const arma::mat X, const arma::mat Y){
arma::mat ret = X * Y;
return ret;
};
arma::mat matmult_ind(const SEXP Xr, const arma::mat Y){
// pre-multplication with index matrix is a permutation of Y's rows:
S4 X(Xr);
arma::uvec perm = X.slot("perm");
arma::mat ret = Y.rows(perm - 1);
return ret;
};
//[[Rcpp::export]]
arma::mat matmult_cpp(SEXP Xr, const arma::mat Y) {
if (Rf_isS4(Xr)) {
if(Rf_inherits(Xr, "dgCMatrix")) {
return matmult_sp(as<arma::sp_mat>(Xr), Y) ;
} ;
if(Rf_inherits(Xr, "indMatrix")) {
return matmult_ind(Xr, Y) ;
} ;
stop("unknown class of Xr") ;
} else {
return matmult_dense(as<arma::mat>(Xr), Y) ;
}
}")
n <- 10000
d <- 20
p <- 30
X <- matrix(rnorm(n*d), n, d)
X_sp <- as(diag(n)[,1:d], "dgCMatrix")
X_ind <- as(sample(1:d, n, rep=TRUE), "indMatrix")
Y <- matrix(1:(d*p), d, p)
matmult_cpp(as(X_ind, "ngTMatrix"), Y)
## Error: unknown class of Xr
all.equal(X%*%Y, matmult_cpp(X, Y))
## [1] TRUE
all.equal(as.vector(X_sp%*%Y),
as.vector(matmult_cpp(X_sp, Y)))
## [1] TRUE
all.equal(X_ind%*%Y, matmult_cpp(X_ind, Y))
## [1] TRUE
编辑:这已经变为Rcpp Gallery post。