Rcpp中的向量均值

时间:2017-01-29 14:56:26

标签: r rcpp

我正在尝试将r函数转换为Rcpp以尝试加速,因为它涉及for循环。一路上我需要计算一个向量条目的平均值,它在R中就像mean(x)一样简单,但它似乎在Rcpp中不起作用,每次都给我0 0作为结果。

我的代码如下所示:

cppFunction(
  "NumericVector fun(int n, double lambda, ...) {
   ...
   NumericVector y = rpois(n, lambda);
   NumericVector w = dpois(y, lambda);
   NumericVector x = w*y;
   double z = mean(x);
   return z;
}")

编辑:所以我认为我的错误是由于上面提到的,并且单个双z的返回只是我试图隔离问题。但是,以下代码仍无效:

cppFunction(
   "NumericVector zstat(int n, double lambda, double lambda0, int m) {
   NumericVector z(m);
   for (int i=1; i<m; ++i){
   NumericVector y = rpois(n, lambda0);
   NumericVector w = dpois(y, lambda)/dpois(y,lambda0);
   double x = mean(w*y);
   z[i] = (x-2)/(sqrt(2/n));
   }
   return z;
}")

1 个答案:

答案 0 :(得分:7)

您的函数的返回类型为NumericVector,但Rcpp::mean返回可转换为double的标量值。解决此问题将解决问题:

library(Rcpp)

cppFunction(
  "double fun(int n, double lambda) {
   NumericVector y = rpois(n, lambda);
   NumericVector w = dpois(y, lambda);
   NumericVector x = w*y;
   double z = mean(x);
   return z;
}")

set.seed(123)
fun(50, 1.5)
# [1] 0.2992908

代码中发生了什么,因为NumericVector被指定为返回类型this constructor is called

template <typename T>
Vector(T size, 
    typename Rcpp::traits::enable_if<traits::is_arithmetic<T>::value, void>::type* = 0) {
    Storage::set__( Rf_allocVector( RTYPE, size) ) ;
    init() ;
}

会将double转换为整数类型,并创建一个NumericVector,其长度等于double的截断值。为了演示,

#include <Rcpp.h>
using namespace Rcpp;

// [[Rcpp::export]]
NumericVector from_double(double x) {
    return x;
}

/*** R

sapply(0.5:4.5, from_double)
# [[1]]
# numeric(0)
#
# [[2]]
# [1] 0
#
# [[3]]
# [1] 0 0
#
# [[4]]
# [1] 0 0 0
#
# [[5]]
# [1] 0 0 0 0

*/

修改:关于第二个问题,您要除以sqrt(2 / n),其中2n都是整数,最终导致除以在大多数情况下为零 - 因此结果向量中的所有Inf值。您可以使用2.0代替2来解决此问题:

#include <Rcpp.h>
using namespace Rcpp;

// [[Rcpp::export]]
NumericVector zstat(int n, double lambda, double lambda0, int m) {
    NumericVector z(m);
    for (int i=1; i<m; ++i){
        NumericVector y = rpois(n, lambda0);
        NumericVector w = dpois(y, lambda)/dpois(y,lambda0);
        double x = mean(w * y);
        // z[i] = (x - 2) / sqrt(2 / n);
        //                       ^^^^^
        z[i] = (x - 2) / sqrt(2.0 / n);
        //                    ^^^^^^^
   }
   return z;
}

/*** R

set.seed(123)
zstat(25, 2, 3, 10)
# [1]  0.0000000 -0.4427721  0.3199805  0.1016661  0.4078687  0.4054078
# [7] -0.1591861  0.9717596  0.6325110  0.1269779

*/

C ++不是R - 您需要更加小心变量的类型。