尝试使用RcppArmadillo编写setdiff()函数会产生编译错误

时间:2015-04-18 23:29:55

标签: c++ r rcpp armadillo

我试图使用RcppArmadillo在C ++中编写一种R的setdiff()函数的类比。我粗略的做法:

  // [[Rcpp::export]]
  arma::uvec my_setdiff(arma::uvec x, arma::uvec y){
  // Coefficientes of unsigned integer vector y form a subset of the coefficients of unsigned integer vector x.
  // Returns set difference between the coefficients of x and those of y
  int n2 = y.n_elem;
  uword q1;
  for (int j=0 ; j<n2 ; j++){
    q1 = find(x==y[j]);
    x.shed_row(q1);
  }
  return x;
  }

在编译时失败。错误如下:

fnsauxarma.cpp:622:29: error: no matching function for call to ‘arma::Col<double>::shed_row(const arma::mtOp<unsigned int, arma::mtOp<unsigned int, arma::Col<double>, arma::op_rel_eq>,     arma::op_find>)’

我真的不知道发生了什么,非常感谢任何帮助或评论。

3 个答案:

答案 0 :(得分:3)

问题在于arma::find返回uvec,并且不知道如何将隐式转换为arma::uword,正如@mtall所指出的那样。您可以使用模板arma::conv_to<T>::from()函数来帮助编译器。另外,我添加了my_setdiff的另一个版本,它返回Rcpp::NumericVector,因为虽然第一个版本返回正确的值,但它在技术上是matrix(即它有维度),并且我假设您希望这与R&#39 {s} setdiff尽可能兼容。这是通过使用dimNULL成员函数将返回向量的R_NilValue属性设置为Rcpp::attr来完成的。


#include <RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]

// [[Rcpp::export]]
arma::uvec my_setdiff(arma::uvec& x, const arma::uvec& y){

  for (size_t j = 0; j < y.n_elem; j++) {
    arma::uword q1 = arma::conv_to<arma::uword>::from(arma::find(x == y[j]));
    x.shed_row(q1);
  }
  return x;
}

// [[Rcpp::export]]
Rcpp::NumericVector my_setdiff2(arma::uvec& x, const arma::uvec& y){

  for (size_t j = 0; j < y.n_elem; j++) {
    arma::uword q1 = arma::conv_to<arma::uword>::from(arma::find(x == y[j]));
    x.shed_row(q1);
  }

  Rcpp::NumericVector x2 = Rcpp::wrap(x);
  x2.attr("dim") = R_NilValue;
  return x2;
}

/*** R
x <- 1:8
y <- 2:6

R> all.equal(setdiff(x,y), my_setdiff(x,y))
#[1] "Attributes: < target is NULL, current is list >" "target is numeric, current is matrix"           

R> all.equal(setdiff(x,y), my_setdiff2(x,y))
#[1] TRUE

R> setdiff(x,y)
#[1] 1 7 8

R> my_setdiff(x,y)
# [,1]
# [1,]    1
# [2,]    7
# [3,]    8

R> my_setdiff2(x,y)
#[1] 1 7 8

*/

修改 为了完整起见,这里有一个比上面提到的两个实现更强大的setdiff版本:

// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadillo.h>

// [[Rcpp::export]]
Rcpp::NumericVector arma_setdiff(arma::uvec& x, arma::uvec& y){

    x = arma::unique(x);
    y = arma::unique(y);

    for (size_t j = 0; j < y.n_elem; j++) {
        arma::uvec q1 = arma::find(x == y[j]);
        if (!q1.empty()) {
            x.shed_row(q1(0));
        }
    }

    Rcpp::NumericVector x2 = Rcpp::wrap(x);
    x2.attr("dim") = R_NilValue;
    return x2;
}

/*** R

x <- 1:10
y <- 2:8

R> all.equal(setdiff(x,y), arma_setdiff(x,y))
#[1] TRUE

X <- 1:6
Y <- c(2,2,3)

R> all.equal(setdiff(X,Y), arma_setdiff(X,Y))
#[1] TRUE
*/

如果您使用非唯一元素传递向量,则以前的版本会抛出错误,例如

R> my_setdiff2(X,Y)

error: conv_to(): given object doesn't have exactly one element

为了解决问题并更接近R {'setdiff,我们只需xy。此外,我使用arma::conv_to<>::from(其中q1(0)现在是q1而不是uvec)切换了uword,因为uvec&# 39; s只是uword s的向量,显式演员看起来有点不雅。

答案 1 :(得分:1)

我使用了来自STL的std::set_difference,来自arma :: uvec的来回转换。

<label>

编辑:我认为性能比较可能是有序的。当集合的相对大小顺序相反时,差异会变小。

#include <RcppArmadillo.h>
#include <algorithm>

// [[Rcpp::depends(RcppArmadillo)]]

// [[Rcpp::export]]
arma::uvec std_setdiff(arma::uvec& x, arma::uvec& y) {

  std::vector<int> a = arma::conv_to< std::vector<int> >::from(arma::sort(x));
  std::vector<int> b = arma::conv_to< std::vector<int> >::from(arma::sort(y));
  std::vector<int> out;

  std::set_difference(a.begin(), a.end(), b.begin(), b.end(),
                      std::inserter(out, out.end()));

  return arma::conv_to<arma::uvec>::from(out);
}

答案 2 :(得分:0)

发问者可能已经找到答案。但是,以下模板版本可能更通用。这等效于Matlab中的setdiff函数

如果P和Q是两组,则它们的差由P-Q或Q-P给出。如果P = {1, 2, 3, 4}Q = {4, 5, 6},则P-Q表示P中不在Q中的元素,即在上面的示例中,P-Q = {1,2,3}。

/* setdiff(t1, t2) is similar to setdiff() function in MATLAB. It removes the common elements and
   gives the uncommon elements in the vectors t1 and t2. */


template <typename T>
T setdiff(T t1, T t2)
{
    int size_of_t1 = size(t1);
    int size_of_t2 = size(t2);

    T Intersection_Elements;
    uvec iA, iB;
    intersect(Intersection_Elements, iA, iB, t1, t2);

    for (int i = 0; i < size(iA); i++)
    {
        t1(iA(i)) = 0;
    }

    for (int i = 0; i < size(iB); i++)
    {
        t2(iB(i)) = 0;
    }

    T t1_t2_vec(size_of_t1 + size_of_t2);
    t1_t2_vec = join_vert(t1, t2);
    T DiffVec = nonzeros(t1_t2_vec);


    return DiffVec;
}

欢迎提出任何改进算法性能的建议。