我尝试使用Rcpp重写一些R代码。但是,我发现性能反而下降了。我已经定位到代码中出问题的那一部分:在这一部分中,我从R的stats包中导入了optimise函数。
我要重写的R代码是:
###################################
# R implementation
# Evaluate phi(x) = 2 * beta^2 * (x - mean)^6 - 3 * beta * (x - mean)^2.
#
# x:    numeric point(s) at which to evaluate phi (vectorised over x).
# mean: centre of the function; defaults to 0.
# beta: scale parameter (no default, must be supplied).
# Returns a numeric vector of the same length as x.
phi_R <- function(x, mean = 0, beta) {
  centred <- x - mean
  sextic_term <- 2 * (beta^2) * (centred^6)
  quadratic_term <- 3 * beta * (centred^2)
  sextic_term - quadratic_term
}
# Bound phi_R() from below and above on the interval [lower, upper].
#
# beta, mean:   parameters forwarded to phi_R().
# lower, upper: end points of the search interval.
# Returns a list with elements `low_bound` and `up_bound`.
bound_phi_R <- function(beta, mean = 0, lower, upper) {
  objective <- function(x) phi_R(x, mean, beta)
  search_interval <- c(lower, upper)

  # Extrema located by optimise() inside the interval (maximum first, then
  # minimum).
  interior_max <- optimise(objective, interval = search_interval,
                           maximum = TRUE)$objective
  interior_min <- optimise(objective, interval = search_interval,
                           maximum = FALSE)$objective

  # optimise() searches the interior only, so the end points are evaluated
  # explicitly and included in the comparison.
  value_at_lower <- objective(lower)
  value_at_upper <- objective(upper)

  list(
    "low_bound" = min(interior_min, value_at_lower, value_at_upper),
    "up_bound" = max(interior_max, value_at_lower, value_at_upper)
  )
}
此函数尝试查找称为phi的特定一维函数的上限和下限。 我的Rcpp实现是:
#include <Rcpp.h>
using namespace Rcpp;
// [[Rcpp::plugins("cpp17")]]
// [[Rcpp::depends(stats)]]
// Evaluates phi(x) = 2*beta^2*(x - mean)^6 - 3*beta*(x - mean)^2.
//
// x:    point at which phi is evaluated.
// mean: centre of the function.
// beta: scale parameter.
double phi_rcpp(const double &x,
                const double &mean,
                const double &beta) {
  const double shifted = x - mean;
  // Same multiplication order as the one-line formula, so results match
  // bit for bit.
  const double sextic_term = 2 * beta * beta * pow(shifted, 6);
  const double quadratic_term = 3 * beta * shifted * shifted;
  return sextic_term - quadratic_term;
}
// Computes lower and upper bounds of phi_rcpp on [lower, upper] by calling
// R's stats::optimise() and then comparing against the end-point values.
//
// mean, beta:   parameters forwarded to phi_rcpp.
// lower, upper: end points of the search interval.
// Returns a list with elements "low_bound" and "up_bound".
//
// Note: every objective evaluation made by optimise() crosses the R <-> C++
// boundary, so this version is expected to be slow.
// [[Rcpp::export]]
Rcpp::List bound_phi_rcpp(const double &mean,
                          const double &beta,
                          const double &lower,
                          const double &upper) {
  // Look optimise up in the stats namespace; unlike "package:stats" this
  // does not require stats to be attached to the search path.
  Rcpp::Environment stats = Rcpp::Environment::namespace_env("stats");
  Rcpp::Function optimise = stats["optimise"];

  // Wrap the C++ objective once and reuse it for both searches.
  Rcpp::InternalFunction objective(&phi_rcpp);

  // mean and beta are forwarded to the objective through optimise's `...`.
  Rcpp::List maxim = optimise(Rcpp::Named("f") = objective,
                              Rcpp::Named("lower") = lower,
                              Rcpp::Named("upper") = upper,
                              Rcpp::Named("maximum") = true,
                              Rcpp::Named("mean") = mean,
                              Rcpp::Named("beta") = beta);
  Rcpp::List minim = optimise(Rcpp::Named("f") = objective,
                              Rcpp::Named("lower") = lower,
                              Rcpp::Named("upper") = upper,
                              Rcpp::Named("maximum") = false,
                              Rcpp::Named("mean") = mean,
                              Rcpp::Named("beta") = beta);

  // optimise() returns the objective value in the element named "objective";
  // read it by name rather than by position.
  const double interior_max = Rcpp::as<double>(maxim["objective"]);
  const double interior_min = Rcpp::as<double>(minim["objective"]);

  // The end points are evaluated explicitly and included in the comparison.
  const double at_upper = phi_rcpp(upper, mean, beta);
  const double at_lower = phi_rcpp(lower, mean, beta);
  const double upper_bound = std::max(interior_max, std::max(at_lower, at_upper));
  const double lower_bound = std::min(interior_min, std::min(at_lower, at_upper));

  return Rcpp::List::create(Rcpp::Named("low_bound") = lower_bound,
                            Rcpp::Named("up_bound") = upper_bound);
}
接下来,我做一些基准测试:
library(Rcpp)
sourceCpp(file = "rcpp.cpp")

# Time 10000 calls of the Rcpp implementation on random intervals in [-2, 2].
pcm <- proc.time()
set.seed(42)
for (i in seq_len(10000)) {
  limits <- runif(2, -2, 2)
  bound_phi_rcpp(beta = 1/4, mean = 0, lower = min(limits), upper = max(limits))
}
test1_time <- proc.time() - pcm

# Time the pure R implementation on the same intervals (same seed).
pcm <- proc.time()
set.seed(42)
for (i in seq_len(10000)) {
  limits <- runif(2, -2, 2)
  bound_phi_R(beta = 1/4, mean = 0, lower = min(limits), upper = max(limits))
}
# Stored as test2_time so that the print() below finds it.
test2_time <- proc.time() - pcm

print(paste("rcpp:", test1_time["elapsed"])) # 5.69 on my machine
print(paste("R:", test2_time["elapsed"])) # 0.0749 on my machine

# benchmarking with rbenchmark
set.seed(42)
limits <- runif(2, -2, 2)
# Compare with a numeric tolerance: the two implementations evaluate phi
# through different code paths, so exact equality of doubles is not guaranteed.
isTRUE(all.equal(
  bound_phi_rcpp(beta = 1/4, mean = 0, lower = min(limits), upper = max(limits)),
  bound_phi_R(beta = 1/4, mean = 0, lower = min(limits), upper = max(limits))
))
rbenchmark::benchmark(
  cpp = bound_phi_rcpp(beta = 1/4, mean = 0, lower = min(limits), upper = max(limits)),
  R = bound_phi_R(beta = 1/4, mean = 0, lower = min(limits), upper = max(limits)),
  replications = 1000
)
我得到以下基准:
test replications elapsed relative user.self sys.self user.child sys.child
1 cpp 1000 0.532 10.231 0.532 0.001 0 0
2 R 1000 0.052 1.000 0.052 0.000 0 0
从stats包导入函数似乎有很多开销。是否有任何方法可以加快此过程,或者Rcpp中是否有等效的优化函数?
答案 0 :(得分:3)
您的C++代码很慢并不奇怪,因为执行过程中需要在R和C++之间频繁地来回切换,而每一次这样的切换都有开销。不过,可以改用完全在C++中实现的优化算法,例如https://www.boost.org/doc/libs/1_70_0/libs/math/doc/html/math_toolkit/brent_minima.html 中的brent_find_minima,它似乎与R使用的算法相同,并且包含在BH包中。事实证明,它也很容易使用:
#include <Rcpp.h>
// [[Rcpp::plugins(cpp11)]]
// [[Rcpp::depends(BH)]]
#include <boost/math/tools/minima.hpp>
// Functor for phi(x) = 2*beta^2*(x - mean)^6 - 3*beta*(x - mean)^2 with
// mean and beta fixed at construction, so it can be passed to a
// one-dimensional optimiser as a unary callable.
class phi_rcpp {
public:
  phi_rcpp(double _mean, double _beta) : mean(_mean), beta(_beta) {}

  // Evaluates phi at x.
  double operator()(const double &x) {
    const double offset = x - mean;
    const double sextic_term = 2 * beta * beta * pow(offset, 6);
    const double quadratic_term = 3 * beta * offset * offset;
    return sextic_term - quadratic_term;
  }

private:
  double mean;  // centre of the function
  double beta;  // scale parameter
};
// Adapter that flips the sign of a functor's result: negate<T>(args...)(x)
// returns -T(args...)(x). Minimising the adapter is therefore the same as
// maximising the wrapped functor.
template<class T>
class negate : public T {
public:
  // Inherit T's constructors so the adapter is built with the same arguments.
  using T::T;

  double operator() (const double &x) {
    const double value = T::operator()(x);
    return -value;
  }
};
// Computes lower and upper bounds of phi on [lower, upper] entirely in C++,
// using Boost's Brent minimiser plus an explicit check of the end points.
//
// mean, beta:   parameters of phi (see class phi_rcpp).
// lower, upper: end points of the search interval.
// Returns a list with elements "low_bound" and "up_bound".
// [[Rcpp::export]]
Rcpp::List bound_phi_rcpp(const double &mean,
                          const double &beta,
                          const double &lower,
                          const double &upper) {
  using boost::math::tools::brent_find_minima;
  const int double_bits = std::numeric_limits<double>::digits;

  phi_rcpp func(mean, beta);
  negate<phi_rcpp> nfunc(mean, beta);

  // brent_find_minima returns (abscissa, function value at that abscissa).
  const std::pair<double, double> interior_min =
      brent_find_minima(func, lower, upper, double_bits);
  // The maximum of phi is found by minimising -phi, so the value returned
  // here is -max(phi) and has to be negated back before it is compared with
  // the (un-negated) end-point values.
  const std::pair<double, double> negated_max =
      brent_find_minima(nfunc, lower, upper, double_bits);
  const double interior_max = -negated_max.second;

  const double at_upper = func(upper);
  const double at_lower = func(lower);

  return Rcpp::List::create(
      Rcpp::Named("low_bound") = std::min(interior_min.second, std::min(at_upper, at_lower)),
      Rcpp::Named("up_bound") = std::max(interior_max, std::max(at_upper, at_lower)));
}
/*** R
# Reference R implementation of
# phi(x) = 2 * beta^2 * (x - mean)^6 - 3 * beta * (x - mean)^2.
# Vectorised over x; mean defaults to 0, beta must be supplied.
phi_R <- function(x, mean = 0, beta) {
  deviation <- x - mean
  2 * (beta^2) * (deviation^6) - 3 * beta * (deviation^2)
}
# Reference R implementation of the bounds of phi_R() on [lower, upper].
# Returns a list with elements `low_bound` and `up_bound`.
bound_phi_R <- function(beta, mean = 0, lower, upper) {
  phi_fixed <- function(x) phi_R(x, mean, beta)
  range_to_search <- c(lower, upper)

  # Extrema reported by optimise() (maximum first, then minimum).
  best_max <- optimise(phi_fixed, interval = range_to_search,
                       maximum = TRUE)$objective
  best_min <- optimise(phi_fixed, interval = range_to_search,
                       maximum = FALSE)$objective

  # The end points are evaluated explicitly and included in the comparison.
  end_lower <- phi_fixed(lower)
  end_upper <- phi_fixed(upper)

  list(
    "low_bound" = min(best_min, end_lower, end_upper),
    "up_bound" = max(best_max, end_lower, end_upper)
  )
}
# Benchmark both implementations on one reproducible random interval.
set.seed(42)
limits <- runif(2, -2, 2)
bench::mark(
  cpp = bound_phi_rcpp(
    beta = 1/4, mean = 0, lower = min(limits), upper = max(limits)
  ),
  R = bound_phi_R(
    beta = 1/4, mean = 0, lower = min(limits), upper = max(limits)
  )
)
*/
这里唯一棘手的事情是求反函子的模板。基准测试结果:
# A tibble: 2 x 13
expression min median `itr/sec` mem_alloc `gc/sec` n_itr n_gc total_time result memory time
<bch:expr> <bch:t> <bch:t> <dbl> <bch:byt> <dbl> <int> <dbl> <bch:tm> <list> <list> <lis>
1 cpp 6.26µs 7.94µs 117496. 2.49KB 11.8 9999 1 85.1ms <list… <Rpro… <bch…
2 R 61.51µs 72.31µs 11279. 124.98KB 11.1 5102 5 452.4ms <list… <Rpro… <bch…
# … with 1 more variable: gc <list>
请注意,bench::mark默认会检查各个表达式返回的结果是否一致。