有什么办法可以加快使用Rcpp进行功能编程的速度吗?

时间:2019-03-23 02:35:08

标签: r rcpp

我试图找出如何有效地将Rcpp合并到R中,并且作为测试用例,我选择了Bisection方法,这是一种流行的求根方法。它还需要进行少量的函数式编程,因为您需要调用要查找其根的函数才能找到根。可能是我错了,只是我的C ++ / Rcpp不好。

问题是Rcpp函数实际上需要花费(更长)的时间才能运行,我很确定,因为每次调用R函数时都会处理R函数。我在下面提供了一些基准测试,以及我的R和C ++版本的代码:

 microbenchmark::microbenchmark(bisectionMethod(function(x) x^3, -10, 7),
 bisectionMethodCPP(function(x) x^3, -10, 7), 
 uniroot(function(x) x^3, c(-10, 7)))

> Unit: microseconds
>                                     expr      min       lq      mean   median
> bisectionMethod(function(x) x^3, -10, 7)   54.521   55.904   69.9303   58.005
> bisectionMethodCPP(function(x) x^3, -10, 7) 1463.781 1496.900 1677.2274 1532.158
>     uniroot(function(x) x^3, c(-10, 7))   62.016   65.862  103.2841   72.320
>    uq      max neval cld
> 62.0835 1100.376   100  a 
> 1604.4215 4941.922   100   b
> 98.5615 2135.752   100  a 

Rcpp:

#include <Rcpp.h>
using namespace Rcpp;

// [[Rcpp::export]]

NumericVector bisectionMethodCPP(Function f, NumericVector a, NumericVector b, 
                          NumericVector eps = 0.0000000001)
{
  if (is_true(all(b <= a))) {
    NumericVector switchVal = a;
    a = b;
    b = switchVal;
  }
  NumericVector fa = f(a);
  NumericVector fb = f(b);
  if (is_true(all(fa*fb > 0))) {
    return -999.0;
  }
  while (is_true(all((b-a) > eps))) {
    NumericVector middle = ((b+a)/2.0);
    NumericVector fmid = f(middle);
    if (is_true(all(fmid*fb > 0))) {
      b = middle;
    }
    else if (is_true(all(fmid*fa > 0))) {
      a = middle;
    }
    else if (is_true(all(fmid == 0.0))) {
      return middle;
    }
    fa = f(a);
    fb = f(b);
  }
  NumericVector Ans = ((b+a)/2.0);
  return Ans;
}

R版本:

bisectionMethod <- function(f, a, b, eps = 1e-10, iter.max = 100) {
  ## R version of the Bisection Method
  if (f(a)*f(b) > 0) {
    stop("Bisection method invalid for this kind of case.")
  }
  if (b <= a) {
    switchVal = a
    a = b
    b = switchVal
  }
  i <- 1
  while ((b-a) > eps) {
    middle = (b+a)/2.0
    if (f(middle) * f(b) > 0) {
      b = middle
    }
    else if (f(middle) * f(a) > 0) {
      a = middle
   }
   else if (f(middle) == 0.0) {
     return(middle)
   }
   i <- i+1
   if (iter.max == i) {
     stop("Maximum iterations reached without locating root.")
   }
 }
 return ((b+a)/2.0)
}

0 个答案:

没有答案