我在Rcpp或RcppArmadillo中找不到很酷的which(x,arr.ind=T)
功能。所以我决定自己快速编写代码。
// [[Rcpp::export]]
arma::umat whicha(arma::mat matrix, int what ){
arma::uvec outp1;
int n = matrix.n_rows;
outp1 = find(matrix==what);
int nf = outp1.n_elem;
arma::mat out(nf,2);
arma::vec foo;
arma::uvec foo2;
foo = arma::conv_to<arma::colvec>::from(outp1) +1;
foo2 = arma::conv_to<arma::uvec>::from(foo);
for(int i=0; i<nf; i++){
out(i,0) = ( foo2(i) %n);
out(i,1) = ceil(foo(i) / n );
if(out(i,0)==0) {
out(i,0)=n;
}
}
return(arma::conv_to<arma::umat>::from(out));
}
代码似乎效率很低,但microbenchmark
显示它可能比R的which
函数更快。
问题:我是否可以进一步更改此功能以实际准确地再现R which
函数,即将MATRIX == something
传递给它?现在我需要第二个论点。我只是想方便一点。
更新:修复了一个错误 - 需要ceil而不是floor
如何检查:
ma=floor(abs(rnorm(100,0,6)))
testf=function(k) {all(which(ma==k,arr.ind=T) == whicha(ma,k))} ; sapply(1:10,testf)
基准:
> microbenchmark(which(ma==k,arr.ind=T) , whicha(ma,k))
Unit: microseconds
expr min lq median uq max neval
which(ma == k, arr.ind = T) 10.264 11.170 11.774 12.377 51.317 100
whicha(ma, k) 3.623 4.227 4.830 5.133 36.224 100
答案 0 :(得分:1)
我会通过生成包装器R函数并执行一些丑陋的工作来处理调用来实现此目的。例如,使用您的代码:
whicha.cpp
----------
#include <RcppArmadillo.h>
// [[Rcpp::depends("RcppArmadillo")]]
// [[Rcpp::export]]
arma::umat whicha(arma::mat matrix, int what ){
arma::uvec outp1;
int n = matrix.n_rows;
outp1 = find(matrix==what);
int nf = outp1.n_elem;
arma::mat out(nf,2);
arma::vec foo;
arma::uvec foo2;
foo = arma::conv_to<arma::vec>::from(outp1) +1;
out.col(1) = floor( foo / n ) +1;
foo2 = arma::conv_to<arma::uvec >::from(foo);
for(int i=0; i<nf; i++){
out(i,0) = foo2(i) % n;
}
return(arma::conv_to<arma::umat >::from(out));
}
/*** R
whichRcpp <- function(x) {
call <- match.call()$x
xx <- eval.parent( call[[2]] )
what <- eval.parent( call[[3]] )
return( whicha(xx, what) )
}
x <- matrix(1:1E4, nrow=1E2)
identical( whichRcpp(x == 100L), whicha(x, 100L) ) ## TRUE
microbenchmark::microbenchmark( whichRcpp(x == 100L), whicha(x, 100L) )
*/
不幸的是,microbenchmark
告诉我解析调用有点慢:
Unit: microseconds
expr min lq median uq max neval
whichRcpp(x == 100L) 43.542 44.143 44.443 45.0440 73.271 100
whicha(x, 100L) 30.029 30.630 30.930 31.2305 78.075 100
你可能值得花时间解析C级别的电话,但我会把它留给你。
答案 1 :(得分:1)
这是我的代码只使用Rcpp:
src <- '
using namespace std;
NumericMatrix X(X_);
double what = as<double>(what_);
int n_rows = X.nrow();
NumericVector rows(0);
NumericVector cols(0);
for(int ii = 0; ii < n_rows * n_rows; ii++)
{
if(X[ii] == what)
{
rows.push_back(ii % n_rows + 1);
cols.push_back(floor(ii / n_rows) + 1);
}
}
return List::create(rows, cols);
'
fun <- inline:::cxxfunction(signature(X_ = 'numeric', what_ = 'numeric'), src, 'Rcpp')
X <- matrix(1:1E4, nrow=1E2)
rbenchmark:::benchmark(fun(X, 100), which(X == 100L, TRUE), columns = c('test', 'replications', 'elapsed', 'relative'), replications = 1000)
test replications elapsed relative
1 fun(X, 100) 1000 0.077 1.000
2 which(X == 100, TRUE) 1000 0.100 1.299
microbenchmark:::microbenchmark(fun(X, 100), which(X == 100L, TRUE), times = 1000L)
expr min lq median uq max neval
fun(X, 100) 37.372 41.3780 43.6530 48.4825 1650.154 1000
which(X == 100L, TRUE) 63.366 64.0745 64.3345 64.8240 1911.858 1000
与上一张海报的解决方案相比,没有那么慢。有趣的是,返回数据框而不是列表会显着降低性能。