RcppArmadillo LogicalMatrix运营

时间:2016-10-26 22:14:11

标签: r rcpp

我想在RcppArmadillo中做一些逻辑矩阵乘法,但是我遇到了一些问题。例如,在R中,可以在以下代码中执行此操作:

times = c(1,2,3)
ti = c(times,4)

lst = c(4,5,6)


st = matrix(lst,nrow=1) %*% outer(times,ti,"<")

结果:

> st
     [,1] [,2] [,3] [,4]
[1,]    0    4    9   15

此处matrix(lst,nrow=1)是1 x 3矩阵,outer(times,ti,"<")是3 x 4逻辑矩阵:

> matrix(lst,nrow=1)
     [,1] [,2] [,3]
[1,]    4    5    6 

> outer(times,ti,"<")
      [,1]  [,2]  [,3] [,4]
[1,] FALSE  TRUE  TRUE TRUE
[2,] FALSE FALSE  TRUE TRUE
[3,] FALSE FALSE FALSE TRUE

RcppArmadillo版本如下:

// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadillo.h>
using namespace Rcpp;

// [[Rcpp::export(".vm")]]
arma::mat vm_mult(const arma::vec lhs,
                  const arma::umat rhs)
{
  return lhs.t() * rhs;
}

// [[Rcpp::export]]
NumericMatrix ty(NumericVector times, NumericVector ti,NumericVector lst){

  LogicalMatrix m = outer(times,ti,std::less<double>());
  NumericMatrix st = vm_mult(lst,m);

  return st;

}

vm_mult是向量矩阵乘法,我将矩阵定义为umat类型,即Mat<unsigned int>。尝试通过sourceCpp运行时出现以下错误:

error: conversion from 'LogicalMatrix' (aka 'Matrix<10>') to 'arma::umat' (aka 'Mat<unsigned int>') is ambiguous
  NumericMatrix st = vm_mult(mag,m);
                                 ^

我还将类型更改为const arma::Mat<unsigned char> rhs,并出现类似错误:

error: conversion from 'LogicalMatrix' (aka 'Matrix<10>') to 'arma::Mat<unsigned char>' is ambiguous
  NumericMatrix st = vm_mult(mag,m);
                                 ^

我检查了Armadillo库的文档,似乎没有专门定义的逻辑矩阵。

那么除了将逻辑矩阵转换为1,0整数矩阵外,我该怎么做。

1 个答案:

答案 0 :(得分:2)

好的,我明白了!它需要使用as<arma::umat>将LogicalMatrix从Rcpp传递到arma :: umat。

以下代码应该可以正常工作。

// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadillo.h>
using namespace Rcpp;



// [[Rcpp::export]]
arma::mat ty(NumericVector times, NumericVector ti,NumericVector mag){

  LogicalMatrix m = outer(times,ti,std::less<double>());
  arma::umat rhs = as<arma::umat>(m);
  arma::vec lhs = as<arma::vec>(mag);
  arma::mat st = lhs.t() * rhs;

  return st;

}

结果:

> sourceCpp('vm.cpp')
> ty(times,ti,lst)
     [,1] [,2] [,3] [,4]
[1,]    0    4    9   15