which(x, arr.ind = T) in Rcpp

Asked: 2013-08-26 02:37:28

Tags: rcpp

I couldn't find the handy which(x, arr.ind=T) function in Rcpp or RcppArmadillo, so I decided to quickly code it myself.

#include <RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]

// [[Rcpp::export]]
arma::umat whicha(arma::mat matrix, int what ){
  arma::uvec outp1;
  int n  =   matrix.n_rows;
  outp1  =   find(matrix==what);
  int nf =   outp1.n_elem;
  arma::mat  out(nf,2);
  arma::vec  foo;
  arma::uvec foo2;
  foo = arma::conv_to<arma::colvec>::from(outp1) +1;  
  foo2 = arma::conv_to<arma::uvec>::from(foo);
  for(int i=0; i<nf; i++){
    out(i,0) = ( foo2(i) %n);
    out(i,1) =  ceil(foo(i) / n ); 
    if(out(i,0)==0) {
      out(i,0)=n;
    }
  }
  return(arma::conv_to<arma::umat>::from(out));
}

The code looks rather inefficient, but microbenchmark shows it can be faster than R's which function.

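Question: can I change this function further so that it really reproduces R's which function, i.e. so that I can pass MATRIX == something to it? Right now I need a second argument. I just want it to be a bit more convenient.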


Update: fixed a bug - the column index needs ceil rather than floor.
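To make the ceil/floor point concrete, here is a minimal standalone sketch of the index arithmetic in plain C++, without Rcpp or Armadillo. The names `arr_ind` and `arr_ind_one_based` are made up for illustration. It assumes column-major storage (which both R and Armadillo use) and 0-based offsets like the ones `find()` returns.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <utility>

// Hypothetical helper, not part of Rcpp or Armadillo.
// Maps a 0-based column-major offset to the 1-based (row, col) pair
// that which(x, arr.ind = TRUE) would report.
std::pair<std::size_t, std::size_t> arr_ind(std::size_t offset, std::size_t n_rows) {
    assert(n_rows > 0);
    return { offset % n_rows + 1, offset / n_rows + 1 };
}

// The same mapping written the way the question does it: start from the
// 1-based position, take ceil for the column, and patch a zero remainder.
std::pair<std::size_t, std::size_t> arr_ind_one_based(std::size_t pos, std::size_t n_rows) {
    assert(n_rows > 0 && pos > 0);
    std::size_t row = pos % n_rows;
    if (row == 0) row = n_rows;  // remainder 0 means the last row
    std::size_t col = static_cast<std::size_t>(
        std::ceil(static_cast<double>(pos) / static_cast<double>(n_rows)));
    return { row, col };
}
```

For a matrix with 3 rows, the element at 0-based offset 5 (1-based position 6) sits in row 3, column 2. With floor(pos / n) + 1 the column would come out as 3, which is the bug the update refers to. Working from the 0-based offset avoids both the ceil and the zero-remainder patch.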

How to check:

ma=floor(abs(rnorm(100,0,6)))
testf=function(k) {all(which(ma==k,arr.ind=T) == whicha(ma,k))} ; sapply(1:10,testf)

Benchmark:

> microbenchmark(which(ma==k,arr.ind=T) , whicha(ma,k))
Unit: microseconds
                        expr    min     lq median     uq    max neval
 which(ma == k, arr.ind = T) 10.264 11.170 11.774 12.377 51.317   100
               whicha(ma, k)  3.623  4.227  4.830  5.133 36.224   100

2 answers:

Answer 0 (score: 1)

I would do this by generating a wrapper R function that does some ugly work to pick the call apart. For example, using your code:

whicha.cpp
----------

#include <RcppArmadillo.h>
// [[Rcpp::depends("RcppArmadillo")]]

// [[Rcpp::export]]
arma::umat whicha(arma::mat matrix, int what ){
  arma::uvec outp1;
  int n =   matrix.n_rows;
  outp1 =   find(matrix==what);
  int nf = outp1.n_elem;
  arma::mat out(nf,2);
  arma::vec foo;
  arma::uvec foo2;

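  // find() returns 0-based column-major offsets; shift them to 1-based positions.
  foo = arma::conv_to<arma::vec>::from(outp1) +1;
  // ceil rather than floor(...) + 1, so an element in the last row keeps its column.
  out.col(1) = ceil(  foo  / n );
  foo2 = arma::conv_to<arma::uvec >::from(foo);
  for(int i=0; i<nf; i++){
    out(i,0) =  foo2(i) % n;
    if(out(i,0)==0) out(i,0) = n;  // a remainder of 0 means the last row
  }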

  return(arma::conv_to<arma::umat >::from(out));
}

/*** R
whichRcpp <- function(x) {
  call <- match.call()$x
  xx <- eval.parent( call[[2]] )
  what <- eval.parent( call[[3]] )
  return( whicha(xx, what) )
}
x <- matrix(1:1E4, nrow=1E2)
identical( whichRcpp(x == 100L), whicha(x, 100L) ) ## TRUE
microbenchmark::microbenchmark( whichRcpp(x == 100L), whicha(x, 100L) )
*/

Unfortunately, microbenchmark tells me that parsing the call is a bit slow:

Unit: microseconds
                 expr    min     lq median      uq    max neval
 whichRcpp(x == 100L) 43.542 44.143 44.443 45.0440 73.271   100
      whicha(x, 100L) 30.029 30.630 30.930 31.2305 78.075   100

It might be worth your time to parse the call at the C level, but I'll leave that to you.
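Another possible route, offered only as a sketch: R evaluates `x == 100L` to a logical matrix before the function is entered, so the function could accept that mask directly instead of taking the call apart. The standalone C++ below shows just the scanning logic. `which_arr_ind` is a made-up name, the mask is assumed to be stored in column-major order, and an Rcpp version would have to take a logical matrix type in place of `std::vector<bool>`. I have not benchmarked this against the versions above.

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical helper: `mask` holds a logical matrix in column-major order
// with `n_rows` rows. Returns the 1-based (row, col) pair of every true
// entry, in the same order which(mask, arr.ind = TRUE) lists them.
std::vector<std::pair<std::size_t, std::size_t>>
which_arr_ind(const std::vector<bool>& mask, std::size_t n_rows) {
    std::vector<std::pair<std::size_t, std::size_t>> hits;
    if (n_rows == 0) return hits;  // nothing to scan, and avoids dividing by zero
    for (std::size_t i = 0; i < mask.size(); ++i) {
        if (mask[i]) {
            // 0-based offset -> 1-based row and column
            hits.emplace_back(i % n_rows + 1, i / n_rows + 1);
        }
    }
    return hits;
}
```

For a 2 x 3 mask stored as {F, T, F, F, T, T}, this yields (2,1), (1,3) and (2,3). The trade-off is that R has to allocate the full logical matrix first, which the two-argument `whicha(x, 100L)` avoids.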

Answer 1 (score: 1)

Here is my code, using only Rcpp:

src <- '
    using namespace std;

    NumericMatrix X(X_);
    double what = as<double>(what_);
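    int n_rows = X.nrow();
    int n_cols = X.ncol();

    NumericVector rows(0);
    NumericVector cols(0);

    // Visit every element in column-major order; using n_rows * n_cols
    // (not n_rows * n_rows) keeps this correct for non-square matrices.
    for(int ii = 0; ii < n_rows * n_cols; ii++)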
    {
        if(X[ii] == what)
        {
            rows.push_back(ii % n_rows + 1);
            cols.push_back(floor(ii / n_rows) + 1);
        }
    }

    return List::create(rows, cols);
'

fun <- inline:::cxxfunction(signature(X_ = 'numeric', what_ = 'numeric'), src, 'Rcpp')

X <- matrix(1:1E4, nrow=1E2)

rbenchmark:::benchmark(fun(X, 100), which(X == 100L, TRUE), columns = c('test', 'replications', 'elapsed', 'relative'), replications = 1000)

                   test replications elapsed relative
1           fun(X, 100)         1000   0.077    1.000
2 which(X == 100, TRUE)         1000   0.100    1.299

microbenchmark:::microbenchmark(fun(X, 100), which(X == 100L, TRUE), times = 1000L)

                   expr    min      lq  median      uq      max neval
            fun(X, 100) 37.372 41.3780 43.6530 48.4825 1650.154  1000
 which(X == 100L, TRUE) 63.366 64.0745 64.3345 64.8240 1911.858  1000

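It is not that slow compared with the previous poster's solution. Interestingly, returning a data frame instead of a list slows it down significantly.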