我有一个很长的参数向量(大约4 ^ 10个元素)和一个索引向量。我的目标是将索引向量中索引的所有参数值加在一起。
例如,如果我有para = [1,2,3,4,5,5,5]和indices = [3,3,1,6]那么我想找到第三个的累积和得到12的值(3)两次,第一个值(1)和第六个(5)。还有一个选项可以根据它们的位置扭曲参数值。
我正在尝试加速R实现,因为我称之为数百万次。
我当前的代码总是返回production
,我看不出它出错的地方
这是Rcpp功能:
NA
这是工作的R版本:
double dot_prod_c(NumericVector indices, NumericVector paras,
NumericVector warp = NA_REAL) {
int len = indices.size();
LogicalVector indices_ok;
for (int i = 0; i < len; i++){
indices_ok.push_back(R_IsNA(indices[i]));
}
if(is_true(any(indices_ok))){
return NA_REAL;
}
double counter = 0;
if(NumericVector::is_na(warp[1])){
for (int i = 0; i < len; i++){
counter += paras[indices[i]];
}
} else {
for (int i = 0; i < len; i++){
counter += paras[indices[i]] * warp[i];
}
}
return counter;
}
以下是使用microbenchmark软件包进行测试和基准测试的一些代码:
dot_prod <- function(indices, paras, warp = NA){
if(is.na(warp[1])){
return(sum(sapply(indices, function(ind) paras[ind + 1])))
} else {
return(sum(sapply(1:length(indices), function(i){
ind <- indices[i]
paras[ind + 1] * warp[i]
})))
}
}
答案 0 :(得分:7)
我在尝试通过您的代码解释您的总体目标时遇到了一些麻烦,所以我只是想解决这个问题
例如,如果我有para = [1,2,3,4,5,5,5]并且index = [3,3,1,6] 然后我想找到第三个值的累积和(3) 两次,第一个值(1)和第六个(5),得到12.有 另外还可以选择根据参数值进行变形 他们的位置。
因为我最清楚。
您的C ++代码存在一些问题。首先,不要这样做 - NumericVector warp = NA_REAL
- 使用Rcpp::Nullable<>
模板(如下所示)。这将解决一些问题:
Nullable
类,它几乎就是它听起来的样子 - 一个可能为空的对象,也可能不为空。 NumericVector warp = NA_REAL
。坦率地说,我很惊讶编译器接受了这一点。 if(NumericVector::is_na(warp[1])){
。这有不明确的行为写在它上面。 这是一个修订版本,取决于您对上述问题的引用说明:
#include <Rcpp.h>
typedef Rcpp::Nullable<Rcpp::NumericVector> nullable_t;
// [[Rcpp::export]]
double DotProd(Rcpp::NumericVector indices, Rcpp::NumericVector params, nullable_t warp_ = R_NilValue) {
R_xlen_t i = 0, n = indices.size();
double result = 0.0;
if (warp_.isNull()) {
for ( ; i < n; i++) {
result += params[indices[i]];
}
} else {
Rcpp::NumericVector warp(warp_);
for ( ; i < n; i++) {
result += params[indices[i]] * warp[i];
}
}
return result;
}
您有一些精心设计的代码来生成示例数据。我没有花时间来完成这个,因为它不是必要的,也不是基准测试。你说自己C ++版本没有产生正确的结果。您的首要任务应该是让您的代码处理简单数据。然后给它提供一些更复杂的数据。然后基准。上面的修订版本适用于简单数据:
args <- list(
indices = c(3, 3, 1, 6),
params = c(1, 2, 3, 4, 5, 5, 5),
warp = c(.25, .75, 1.25, 1.75)
)
all.equal(
DotProd(args[[1]], args[[2]]),
dot_prod(args[[1]], args[[2]]))
#[1] TRUE
all.equal(
DotProd(args[[1]], args[[2]], args[[3]]),
dot_prod(args[[1]], args[[2]], args[[3]]))
#[1] TRUE
此示例数据的速度也比R版本快。我没有理由相信它不适用于更大,更复杂的数据 - 对于* apply函数来说,没有任何神奇或特别的效率;他们只是更惯用/可读R.
microbenchmark::microbenchmark(
"Rcpp" = DotProd(args[[1]], args[[2]]),
"R" = dot_prod(args[[1]], args[[2]]))
#Unit: microseconds
#expr min lq mean median uq max neval
#Rcpp 2.463 2.8815 3.52907 3.3265 3.8445 18.823 100
#R 18.869 20.0285 21.60490 20.4400 21.0745 66.531 100
#
microbenchmark::microbenchmark(
"Rcpp" = DotProd(args[[1]], args[[2]], args[[3]]),
"R" = dot_prod(args[[1]], args[[2]], args[[3]]))
#Unit: microseconds
#expr min lq mean median uq max neval
#Rcpp 2.680 3.0430 3.84796 3.701 4.1360 12.304 100
#R 21.587 22.6855 23.79194 23.342 23.8565 68.473 100
我省略了上面例子中的NA
检查,但是通过使用一点Rcpp糖也可以修改为更惯用的东西。以前,你这样做:
LogicalVector indices_ok;
for (int i = 0; i < len; i++){
indices_ok.push_back(R_IsNA(indices[i]));
}
if(is_true(any(indices_ok))){
return NA_REAL;
}
它有点咄咄逼人 - 您正在测试整个值向量(使用R_IsNA
),然后应用is_true(any(indices_ok))
- 当您可能过早地中断并返回NA_REAL
时在R_IsNA(indices[i])
的第一个实例上生成true
。此外,push_back
的使用会使您的功能相当慢 - 您最好将indices_ok
初始化为已知大小并通过循环中的索引访问来填充它。然而,这是压缩操作的一种方法:
if (Rcpp::na_omit(indices).size() != indices.size()) return NA_REAL;
为了完整性,这里有一个完全糖化的版本,可以让你完全避免循环:
#include <Rcpp.h>
typedef Rcpp::Nullable<Rcpp::NumericVector> nullable_t;
// [[Rcpp::export]]
double DotProd3(Rcpp::NumericVector indices, Rcpp::NumericVector params, nullable_t warp_ = R_NilValue) {
if (Rcpp::na_omit(indices).size() != indices.size()) return NA_REAL;
if (warp_.isNull()) {
Rcpp::NumericVector tmp = params[indices];
return Rcpp::sum(tmp);
} else {
Rcpp::NumericVector warp(warp_), tmp = params[indices];
return Rcpp::sum(tmp * warp);
}
}
/*** R
all.equal(
DotProd3(args[[1]], args[[2]]),
dot_prod(args[[1]], args[[2]]))
#[1] TRUE
all.equal(
DotProd3(args[[1]], args[[2]], args[[3]]),
dot_prod(args[[1]], args[[2]], args[[3]]))
#[1] TRUE
*/