dmvnorm MVN密度 - RcppArmadillo实现比R包慢,包括一点Fortran

时间:2013-07-12 14:43:46

标签: r rcpp

解决方案现已在 Rcpp Gallery

中联机

我从RcppArmadillo的mvtnorm包中重新实现了dmvnorm。我不知何故喜欢犰狳,但我想它也适用于普通的Rcpp。来自dmvnorm的方法基于马哈拉诺比斯距离,所以我有一个函数,然后是多元正态密度函数。

让我告诉你我的代码:

#include <RcppArmadillo.h>
#include <Rcpp.h>

// [[Rcpp::depends("RcppArmadillo")]]

// [[Rcpp::export]]
arma::vec mahalanobis_arma( arma::mat x ,  arma::mat mu, arma::mat sigma ){

  int n = x.n_rows;
  arma::vec md(n);
    for (int i=0; i<n; i++){
        arma::mat x_i = x.row(i) - mu;
        arma::mat Y = arma::solve( sigma, arma::trans(x_i) );
        md(i) = arma::as_scalar(x_i * Y);
    }
    return md;

    }



// [[Rcpp::export]]
arma::vec dmvnorm ( arma::mat x,  arma::mat mean,  arma::mat sigma, bool log){ 

arma::vec distval = mahalanobis_arma(x,  mean, sigma);

    double logdet = sum(arma::log(arma::eig_sym(sigma)));
    double log2pi = 1.8378770664093454835606594728112352797227949472755668;
    arma::vec logretval = -( (x.n_cols * log2pi + logdet + distval)/2  ) ;

       if(log){ 
         return(logretval);

       }else { 
       return(exp(logretval));
         }
}

所以,而不是我的失望:

模拟一些数据

sigma <- matrix(c(4,2,2,3), ncol=2)
x <- rmvnorm(n=5000000, mean=c(1,2), sigma=sigma, method="chol")

和基准

system.time(mvtnorm::dmvnorm(x,t(1:2),.2+diag(2),F))
   user  system elapsed 
   0.05    0.02    0.06 

system.time(dmvnorm(x,t(1:2),.2+diag(2),F))
   user  system elapsed 
   0.12    0.02    0.14 

没有!!!!!! : - (

[编辑]

问题是: 1)为什么RcppArmadillo实现比普通R实现慢? 2)如何创建一个优于R实现的Rcpp / RcppArmadillo实现?

[编辑2]

我将mahalanobis_arma放入mvtnorm :: dmvnorm函数中,它也会慢下来。

1 个答案:

答案 0 :(得分:8)

如果你想更快地实现mahalanobis距离,你只需要重新编写算法并模仿R使用的算法。这很简单

我修改了你的函数mahalanobis_arma以将mu变为rowvec

基本上我只是将R代码翻译成RcppArmadillo

mahalanobis
function (x, center, cov, inverted = FALSE, ...) 
{
    x <- if (is.vector(x)) 
        matrix(x, ncol = length(x))
    else as.matrix(x)
    x <- sweep(x, 2, center)
    if (!inverted) 
        cov <- solve(cov, ...)
    setNames(rowSums((x %*% cov) * x), rownames(x))
}
<bytecode: 0x6e5b408>
<environment: namespace:stats>

这是

#include <RcppArmadillo.h>
#include <Rcpp.h>

// [[Rcpp::depends("RcppArmadillo")]]
// [[Rcpp::export]]
arma::vec Mahalanobis(arma::mat x, arma::rowvec center, arma::mat cov){
    int n = x.n_rows;
    arma::mat x_cen;
    x_cen.copy_size(x);
    for (int i=0; i < n; i++) {
        x_cen.row(i) = x.row(i) - center;
    }
    return sum((x_cen * cov.i()) % x_cen, 1);    
}


// [[Rcpp::export]]
arma::vec mahalanobis_arma( arma::mat x ,  arma::rowvec mu, arma::mat sigma ){

  int n = x.n_rows;
  arma::vec md(n);
    for (int i=0; i<n; i++){
        arma::mat x_i = x.row(i) - mu;
        arma::mat Y = arma::solve( sigma, arma::trans(x_i) );
        md(i) = arma::as_scalar(x_i * Y);
    }
    return md;

    }

现在,让我们比较一下这个新的犰狳版本(Mahalanobis),您的第一个版本(mahalanobis_arma)和R实现版本(mahalanobis)。

我将此Cpp代码保存为mahalanobis.cpp

require(RcppArmadillo)
sourceCpp("mahalanobis.cpp")

set.seed(1)
x <- matrix(rnorm(10000 * 10), ncol = 10)
Sx <- cov(x)


all.equal(c(Mahalanobis(x, colMeans(x), Sx))
          ,mahalanobis(x, colMeans(x), Sx))
## [1] TRUE

all.equal(mahalanobis_arma(x, colMeans(x), Sx)
          ,Mahalanobis(x, colMeans(x), Sx))
## [1] TRUE


require(rbenchmark)
benchmark(Mahalanobis(x, colMeans(x), Sx),
          mahalanobis(x, colMeans(x), Sx),
          mahalanobis_arma(x, colMeans(x), Sx),
          order = "elapsed")


##                                   test replications elapsed
## 1      Mahalanobis(x, colMeans(x), Sx)          100   0.124
## 2      mahalanobis(x, colMeans(x), Sx)          100   0.741
## 3 mahalanobis_arma(x, colMeans(x), Sx)          100   4.509
##   relative user.self sys.self user.child sys.child
## 1    1.000     0.173    0.077          0         0
## 2    5.976     0.804    0.670          0         0
## 3   36.363     4.386    4.626          0         0

正如您所看到的,新实现比R更快。 我很确定我们可以通过使用cholesky分解来解决协方差矩阵或使用其他矩阵分解来做得更好。

最后,我们可以将此Mahalanobis函数插入您的dmvnorm并对其进行测试:

require(mvtnorm)
set.seed(1)
sigma <- matrix(c(4, 2, 2, 3), ncol = 2)
x <- rmvnorm(n = 5000000, mean = c(1, 2), sigma = sigma, method = "chol")


all.equal(mvtnorm::dmvnorm(x, t(1:2), .2 + diag(2), FALSE),
          c(dmvnorm(x, t(1:2), .2+diag(2), FALSE)))
## [1] TRUE

benchmark(mvtnorm::dmvnorm(x, t(1:2), .2 + diag(2), FALSE),
          dmvnorm(x, t(1:2), .2+diag(2), FALSE),
          order = "elapsed")

##                                                test replications
## 2          dmvnorm(x, t(1:2), 0.2 + diag(2), FALSE)          100
## 1 mvtnorm::dmvnorm(x, t(1:2), 0.2 + diag(2), FALSE)          100
##   elapsed relative user.self sys.self user.child sys.child
## 2  35.366    1.000    31.117    4.193          0         0
## 1  60.770    1.718    56.666   13.236          0         0

现在快几倍了。