我想在RcppArmadillo中做一些逻辑矩阵乘法,但是我遇到了一些问题。例如,在R中,可以在以下代码中执行此操作:
times = c(1,2,3)
ti = c(times,4)
lst = c(4,5,6)
st = matrix(lst,nrow=1) %*% outer(times,ti,"<")
结果:
> st
[,1] [,2] [,3] [,4]
[1,] 0 4 9 15
此处matrix(lst,nrow=1)
是1 x 3矩阵,outer(times,ti,"<")
是3 x 4逻辑矩阵:
> matrix(lst,nrow=1)
[,1] [,2] [,3]
[1,] 4 5 6
> outer(times,ti,"<")
[,1] [,2] [,3] [,4]
[1,] FALSE TRUE TRUE TRUE
[2,] FALSE FALSE TRUE TRUE
[3,] FALSE FALSE FALSE TRUE
RcppArmadillo版本如下:
// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadillo.h>
using namespace Rcpp;
// [[Rcpp::export(".vm")]]
arma::mat vm_mult(const arma::vec lhs,
const arma::umat rhs)
{
return lhs.t() * rhs;
}
// [[Rcpp::export]]
NumericMatrix ty(NumericVector times, NumericVector ti,NumericVector lst){
LogicalMatrix m = outer(times,ti,std::less<double>());
NumericMatrix st = vm_mult(lst,m);
return st;
}
vm_mult
是向量矩阵乘法,我将矩阵定义为umat
类型,即Mat<unsigned int>
。尝试通过sourceCpp运行时出现以下错误:
error: conversion from 'LogicalMatrix' (aka 'Matrix<10>') to 'arma::umat' (aka 'Mat<unsigned int>') is ambiguous
NumericMatrix st = vm_mult(mag,m);
^
我还将类型更改为const arma::Mat<unsigned char> rhs
,并出现类似错误:
error: conversion from 'LogicalMatrix' (aka 'Matrix<10>') to 'arma::Mat<unsigned char>' is ambiguous
NumericMatrix st = vm_mult(mag,m);
^
我检查了Armadillo库的文档,似乎没有专门定义的逻辑矩阵。
那么除了将逻辑矩阵转换为1,0整数矩阵外,我该怎么做。
答案 0 :(得分:2)
好的,我明白了!它需要使用as<arma::umat>
将LogicalMatrix从Rcpp传递到arma :: umat。
以下代码应该可以正常工作。
// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadillo.h>
using namespace Rcpp;
// [[Rcpp::export]]
arma::mat ty(NumericVector times, NumericVector ti,NumericVector mag){
LogicalMatrix m = outer(times,ti,std::less<double>());
arma::umat rhs = as<arma::umat>(m);
arma::vec lhs = as<arma::vec>(mag);
arma::mat st = lhs.t() * rhs;
return st;
}
结果:
> sourceCpp('vm.cpp')
> ty(times,ti,lst)
[,1] [,2] [,3] [,4]
[1,] 0 4 9 15