将Rcpp函数扩展为任何类型的输入向量

时间:2017-10-01 17:35:10

标签: c++ r rcpp

我有以下函数,它在NumericVector上执行一个简单循环并返回int类型值。

  Rcpp::cppFunction({'
  int calc_streak( NumericVector x, int i1, int i2){
  int cur_streak=1;

  if (NumericVector::is_na(x[0])){
    cur_streak = NumericVector::get_na();
  } else {
    cur_streak = 1;
  }

  for(int j = i1; j <= i2 ; ++j) {
    if( x[ j ] == x[ j-1 ]){
      cur_streak += 1;

    } else if(NumericVector::is_na( x[ j ] )){
      cur_streak = NumericVector::get_na();

    } else {
      cur_streak = 1;

    }
  }
  return cur_streak;
}
"})

calc_streak(c(1,1,1,1),i1=0,i2=3)
# [1] 4

功能对我来说很好,但真正的问题是我试图在其他输入类型上扩展此功能。我一直在搜索堆栈herehere,但这些示例在我的情况下不起作用,或者我不知道如何正确使用示例。我尝试过几种处理未知输入类型的方法,但在我的情况下没有一种方法成功。 以下三个例子

  1. this启发的最简单的一个 - 创建了main函数,该函数根据参数TYPEOF(x)的类型运行以前定义的函数之一。 此函数返回integernumeric的预期值。对于character会话崩溃

    Rcpp::cppFunction('
    #include <Rcpp.h>
    using namespace Rcpp;
    
    int streak_run_int(IntegerVector x, int i1, int i2){
      int cur_streak=1;
    
      if (IntegerVector::is_na(x[0])){
        cur_streak = NumericVector::get_na();
      } else {
       cur_streak = 1;
      }
    
      for(int j = i1; j <= i2 ; ++j) {
        if( x[ j ] == x[ j-1 ]){
         cur_streak += 1;
    
        } else if(IntegerVector::is_na( x[ j ] )){
          cur_streak = NumericVector::get_na();
    
        } else {
          cur_streak = 1;
    
        }
      }
      return cur_streak;
    }
    
    int streak_run_char(CharacterVector x, int i1, int i2){
      int cur_streak=1;
    
      if (CharacterVector::is_na(x[0])){
        cur_streak = NumericVector::get_na();
      } else {
        cur_streak = 1;
      }
    
      for(int j = i1; j <= i2 ; ++j) {
        if( x[ j ] == x[ j-1 ]){
        cur_streak += 1;
    
        } else if(CharacterVector::is_na( x[ j ] )){
          cur_streak = NumericVector::get_na();
    
        } else {
          cur_streak = 1;
    
        }
      }
      return cur_streak;
    }
    
    
      // [[Rcpp::export]]
    int streak_run4(SEXP x, int i1, int i2) {
      switch (TYPEOF(x)) {
      case INTSXP: {
        return streak_run_int(as<IntegerVector>(x), i1, i2);
      }
      case STRSXP: {
        return streak_run_char(as<CharacterVector>(x), i1, i2);
      }
      default: { return 0; }
      }
    }
    ')
    
    # expected results for int and real - for character session crashes
    streak_run4( c(1,1,1,1),i1=0, i2=3)
    streak_run4( as.integer(c(1,1,1,1)),i1=0, i2=3)
    streak_run4( as.character(c(1,1,1,1)),i1=0, i2=3) 
    
    1. 第二个功能有完全相同的想法,但使用模板而不是定义多个功能。与上述结果相同 - 会话崩溃character输入

      Rcpp::cppFunction('
      #include <Rcpp.h>
      using namespace Rcpp;
      
      namespace impl {
      
        template <int RTYPE>
          int streak_run_impl(const Vector<RTYPE>& x, int i1, int i2)
        {
          int cur_streak=1;
      
          if (Vector<RTYPE>::is_na(x[0])){
            cur_streak = NumericVector::get_na();
          } else {
            cur_streak = 1;
          }
      
          for(int j = i1; j <= i2 ; ++j) {
            if( x[ j ] == x[ j-1 ]){
              cur_streak += 1;
      
            } else if(Vector<RTYPE>::is_na( x[ j ] )){
              cur_streak = NumericVector::get_na();
      
            } else {
              cur_streak = 1;
      
            }
          }
          return cur_streak;
          }
      
      }
      
      // [[Rcpp::export]]
      int streak_run3(SEXP x, int i1, int i2) {
        switch (TYPEOF(x)) {
        case INTSXP: {
          return impl::streak_run_impl(as<IntegerVector>(x), i1, i2);
        }
        case REALSXP: {
          return impl::streak_run_impl(as<NumericVector>(x), i1, i2);
        }
        case STRSXP: {
          return impl::streak_run_impl(as<CharacterVector>(x), i1, i2);
        }
        case LGLSXP: {
          return impl::streak_run_impl(as<LogicalVector>(x), i1, i2);
        }
        case CPLXSXP: {
          return impl::streak_run_impl(as<ComplexVector>(x), i1, i2);
        }
        default: {
          return 0;
        }
        }
      }
      ')
      
      streak_run3( c(1,1,1,1),i1=0, i2=3)
      streak_run3( as.integer(c(1,1,1,1)),i1=0, i2=3)
      streak_run3( as.character(c(1,1,1,1)),i1=0, i2=3)
      
      1. 另一个受this article启发,这次我甚至无法编译C ++函数,同时出现错误use of overloaded operator '==' is ambiguous。无论如何,在检查了上面的两个例子之后,我不期待任何其他结果。

        Rcpp::cppFunction('
        #include <Rcpp.h>
        using namespace Rcpp;
        
        class streak_run2_impl {
          private:
          int i1;
          int i2;
        
          public:
          streak_run2_impl(int i1, int i2) : i1(i1), i2(i2) {}
        
          template <int RTYPE>
          IntegerVector operator()(const Vector<RTYPE>& x)
          {
        
            int cur_streak=1;
        
            if (Vector<RTYPE>::is_na(x[0])){
            cur_streak = NumericVector::get_na();
            } else {
            cur_streak = 1;
            }
        
            for(int j = i1; j <= i2 ; ++j) {
              if( x[ j ] == x[ j-1 ] ){
                cur_streak += 1;
        
              } else if(Vector<RTYPE>::is_na( x[ j ] )){
        
                cur_streak = NumericVector::get_na();
        
              } else {
                cur_streak = 1;
              }
            }
            return cur_streak;
          }
        };
        
        
        // [[Rcpp::export]]
        RObject streak_run2(RObject x, int i1 = 0, int i2=6){
          RCPP_RETURN_VECTOR(streak_run2_impl(i1, i2), x);
        }
        ')
        
      2. 所以我的问题是: 如何正确定义该函数以获得任何R类的输入向量的结果?
        我将不得不提出任何建议。

2 个答案:

答案 0 :(得分:3)

首先,好帖子!遗憾的是,由于完全不同的错误未在原型函数中获取,因此您链接的上述资源都不会与您的问题相关。对于为什么原型在调用时返回了一个有效的值,这是纯粹的运气。

正如@BenjaminChristoffersen指出的那样,由于发生undefined behavior (UB)错误,代码遇到out-of-bounds (OOB)。他的解决方案将有效地解决问题。

但是,要在将来自行诊断,请从使用元素访问器[]切换到(),它会检查您请求的元素是否在边界内。例如 0 0 n - 1 j 吗?

e.g。

  if (Vector<RTYPE>::is_na( x( 0 ) )){
  // ------------------------^---^

    cur_streak = NumericVector::get_na();
  } else {
    cur_streak = 1;
  }

  for(int j = i1; j <= i2 ; ++j) {
    if( x( j ) == x( j-1 )){
      // ^---^-----^-----^
      cur_streak += 1;

    } else if(Vector<RTYPE>::is_na( x( j ) )){
      // --------------------------- ^   ^

      cur_streak = NumericVector::get_na();

    } else {
      cur_streak = 1;

    }
  }

运行相同的命令,然后给出:

streak_run3( c(1,1,1,1),i1=0, i2=3)

输出:

Error in streak_run3(c(1, 1, 1, 1), i1 = 0, i2 = 3) : 
  Index out of bounds: [index=-1; extent=4].

输入:

streak_run3( as.integer(c(1,1,1,1)),i1=0, i2=3)

输出:

Error in streak_run3(as.integer(c(1, 1, 1, 1)), i1 = 0, i2 = 3) : 
  Index out of bounds: [index=-1; extent=4].

输入:

streak_run3( as.character(c(1,1,1,1)),i1=0, i2=3)

输出:

Error in streak_run3(as.character(c(1, 1, 1, 1)), i1 = 0, i2 = 3) : 
  Index out of bounds: [index=-1; extent=4].

答案 1 :(得分:2)

我认为示例中的主要错误是您在j = 0开始循环,因此您致电operator[](-1)。以下适用于我。进行以下func.cpp

#include <Rcpp.h>
#include <algorithm>
using namespace Rcpp;

template <int RTYPE>
int streak_run_impl(const Vector<RTYPE>& x, int i1, int i2)
{
  int cur_streak = 1;

  if (Vector<RTYPE>::is_na(x[0])){
    cur_streak = NA_INTEGER;
  } else {
    cur_streak = 1;
  }

  for(int j = std::max(i1, 1) /* have to start at one at least */; 
      j < std::min(i2 + 1, (int)x.size()) /* check size of x */; ++j){
    if(x[j] == x[j - 1]){
      cur_streak += 1;

    } else if(Vector<RTYPE>::is_na(x[j])){
      cur_streak = NA_INTEGER;

    } else {
      cur_streak = 1;

    }
  }
  return cur_streak;
}

// [[Rcpp::export]]
int streak_run3(SEXP x, int i1, int i2) {
  switch (TYPEOF(x)) {
    case INTSXP: {
      return streak_run_impl(as<IntegerVector>(x), i1, i2);
    }
    case REALSXP: {
      return streak_run_impl(as<NumericVector>(x), i1, i2);
    }
    case STRSXP: {
      return streak_run_impl(as<CharacterVector>(x), i1, i2);
    }
    case LGLSXP: {
      return streak_run_impl(as<LogicalVector>(x), i1, i2);
    }
    case CPLXSXP: {
      return streak_run_impl(as<ComplexVector>(x), i1, i2);
    }
    default: {
      return 0;
    }
  }
}

然后运行此R脚本,并将工作目录设置为.cpp文件

的工作目录
Rcpp::sourceCpp("func.cpp")

streak_run3(c(1,1,1,1), i1=0, i2=3)
streak_run3(as.integer(c(1,1,1,1)), i1=0, i2=3)
streak_run3(as.character(c(1,1,1,1)), i1=0, i2=3)