Matrix multiplication in R is highly optimized; it is really just a call down to BLAS/LAPACK. Still, I was surprised to find that this very naive C++ code for matrix-vector multiplication seems reliably about 30% faster.
library(Rcpp)
# Simple C++ code for matrix multiplication
mm_code =
"NumericVector my_mm(NumericMatrix m, NumericVector v){
int nRow = m.rows();
int nCol = m.cols();
NumericVector ans(nRow);
double v_j;
for(int j = 0; j < nCol; j++){
v_j = v[j];
for(int i = 0; i < nRow; i++){
ans[i] += m(i,j) * v_j;
}
}
return(ans);
}
"
# Compiling
my_mm = cppFunction(code = mm_code)
# Simulating data to use
nRow = 10^4
nCol = 10^4
m = matrix(rnorm(nRow * nCol), nrow = nRow)
v = rnorm(nCol)
system.time(my_ans <- my_mm(m, v))
#> user system elapsed
#> 0.103 0.001 0.103
system.time(r_ans <- m %*% v)
#> user system elapsed
#> 0.154 0.001 0.154
# Double checking answer is correct
max(abs(my_ans - r_ans))
#> [1] 0
Does base R's %*% perform some kind of data checking that I am skipping?
Edit:
After understanding what is going on (thanks!), it is worth noting that this is a worst-case scenario for R's %*%, i.e. matrix by vector. For example, @RalfStubner pointed out that an RcppArmadillo implementation of matrix-vector multiplication is even faster than the naive implementation I demonstrated, hence considerably faster than base R, yet it is virtually identical to base R's %*% for matrix-matrix multiplication (when both matrices are big and square):
arma_code <-
"arma::mat arma_mm(const arma::mat& m, const arma::mat& m2) {
return m * m2;
};"
arma_mm = cppFunction(code = arma_code, depends = "RcppArmadillo")
nRow = 10^3
nCol = 10^3
mat1 = matrix(rnorm(nRow * nCol),
nrow = nRow)
mat2 = matrix(rnorm(nRow * nCol),
nrow = nRow)
system.time(arma_mm(mat1, mat2))
#> user system elapsed
#> 0.798 0.008 0.814
system.time(mat1 %*% mat2)
#> user system elapsed
#> 0.807 0.005 0.822
So R's current (v3.5.0) %*% is near-optimal for matrix-matrix multiplication, but matrix-vector multiplication can be sped up significantly if you are able to skip the checks.
Answer 0 (score: 27)
A quick glance at names.c (here in particular) points you to do_matprod, the C function that %*% calls, which lives in the file array.c. (Interestingly, it turns out that both crossprod and tcrossprod dispatch to that same function as well.) Here is a link to the code of do_matprod.
Scrolling through the function, you can see that it takes care of a number of things your naive implementation does not, including:
- Keeping row and column names, where that makes sense.
- Allowing dispatch to alternative S4 methods when the two objects operated on by a call to %*% belong to classes for which such methods have been provided. (That is what happens in this portion of the function.)
- Handling both real and complex matrices.
- Implementing a series of rules for how to handle multiplication of matrix by matrix, vector by matrix, matrix by vector, and vector by vector.
Near the end of the function, it dispatches to either matprod or cmatprod. Interestingly (to me at least), in the case of real matrices, if either matrix might contain NaN or Inf values, matprod dispatches (here) to a function called simple_matprod, which is about as simple and straightforward as your own. Otherwise, it dispatches to one of a couple of BLAS Fortran routines, which are presumably faster if uniformly "well-behaved" matrix elements can be guaranteed.
Answer 1 (score: 7)
Josh's answer explains why R's matrix multiplication is not as fast as this naive approach. I was curious to see how much one could gain with RcppArmadillo. The code is simple enough:
arma_code <-
"arma::vec arma_mm(const arma::mat& m, const arma::vec& v) {
return m * v;
};"
arma_mm = cppFunction(code = arma_code, depends = "RcppArmadillo")
The benchmark:
> microbenchmark::microbenchmark(my_mm(m,v), m %*% v, arma_mm(m,v), times = 10)
Unit: milliseconds
expr min lq mean median uq max neval
my_mm(m, v) 71.23347 75.22364 90.13766 96.88279 98.07348 98.50182 10
m %*% v 92.86398 95.58153 106.00601 111.61335 113.66167 116.09751 10
arma_mm(m, v) 41.13348 41.42314 41.89311 41.81979 42.39311 42.78396 10
So RcppArmadillo gives us nicer syntax and better performance.
Curiosity got the better of me. Here is a solution that uses BLAS directly:
blas_code = "
NumericVector blas_mm(NumericMatrix m, NumericVector v){
int nRow = m.rows();
int nCol = m.cols();
NumericVector ans(nRow);
char trans = 'N';
double one = 1.0, zero = 0.0;
int ione = 1;
F77_CALL(dgemv)(&trans, &nRow, &nCol, &one, m.begin(), &nRow, v.begin(),
&ione, &zero, ans.begin(), &ione);
return ans;
}"
blas_mm <- cppFunction(code = blas_code, includes = "#include <R_ext/BLAS.h>")
The benchmark:
Unit: milliseconds
expr min lq mean median uq max neval
my_mm(m, v) 72.61298 75.40050 89.75529 96.04413 96.59283 98.29938 10
m %*% v 95.08793 98.53650 109.52715 111.93729 112.89662 128.69572 10
arma_mm(m, v) 41.06718 41.70331 42.62366 42.47320 43.22625 45.19704 10
blas_mm(m, v) 41.58618 42.14718 42.89853 42.68584 43.39182 44.46577 10
Armadillo and BLAS (OpenBLAS in my case) are almost identical. And the BLAS code is what R does in the end as well. So 2/3 of what R does is error checking etc.
Answer 2 (score: 2)
To add one more point to Ralf Stubner's solution, you can use the C++ version below, where __restrict__ possibly allows vector operations (it likely does not matter here, though, since these are only reads, I guess).
#include <Rcpp.h>
using namespace Rcpp;
inline void mat_vec_mult_vanilla
(double const * __restrict__ m,
double const * __restrict__ v,
double * __restrict__ const res,
size_t const dn, size_t const dm) noexcept {
for(size_t j = 0; j < dm; ++j, ++v){
double * r = res;
for(size_t i = 0; i < dn; ++i, ++r, ++m)
*r += *m * *v;
}
}
inline void mat_vec_mult
(double const * __restrict__ const m,
double const * __restrict__ const v,
double * __restrict__ const res,
size_t const dn, size_t const dm) noexcept {
size_t j(0L);
double const * vj = v,
* mi = m;
constexpr size_t const ncl(8L);
{
double const * mvals[ncl];
size_t const end_j = dm - (dm % ncl),
inc = ncl * dn;
for(; j < end_j; j += ncl, vj += ncl, mi += inc){
double *r = res;
mvals[0] = mi;
for(size_t i = 1; i < ncl; ++i)
mvals[i] = mvals[i - 1L] + dn;
for(size_t i = 0; i < dn; ++i, ++r)
for(size_t ii = 0; ii < ncl; ++ii)
*r += *(vj + ii) * *mvals[ii]++;
}
}
mat_vec_mult_vanilla(mi, vj, res, dn, dm - j);
}
// [[Rcpp::export("mat_vec_mult", rng = false)]]
NumericVector mat_vec_mult_cpp(NumericMatrix m, NumericVector v){
size_t const dn = m.nrow(),
dm = m.ncol();
NumericVector res(dn);
mat_vec_mult(&m[0], &v[0], &res[0], dn, dm);
return res;
}
// [[Rcpp::export("mat_vec_mult_vanilla", rng = false)]]
NumericVector mat_vec_mult_vanilla_cpp(NumericMatrix m, NumericVector v){
size_t const dn = m.nrow(),
dm = m.ncol();
NumericVector res(dn);
mat_vec_mult_vanilla(&m[0], &v[0], &res[0], dn, dm);
return res;
}
The result with -O3 in my Makevars file and gcc-8.3 is
set.seed(1)
dn <- 10001L
dm <- 10001L
m <- matrix(rnorm(dn * dm), dn, dm)
lv <- rnorm(dm)
all.equal(drop(m %*% lv), mat_vec_mult(m = m, v = lv))
#R> [1] TRUE
all.equal(drop(m %*% lv), mat_vec_mult_vanilla(m = m, v = lv))
#R> [1] TRUE
bench::mark(
R = m %*% lv,
`OP's version` = my_mm(m = m, v = lv),
`BLAS` = blas_mm(m = m, v = lv),
`C++ vanilla` = mat_vec_mult_vanilla(m = m, v = lv),
`C++` = mat_vec_mult(m = m, v = lv), check = FALSE)
#R> # A tibble: 5 x 13
#R> expression min median `itr/sec` mem_alloc `gc/sec` n_itr n_gc total_time result memory time gc
#R> <bch:expr> <bch:tm> <bch:tm> <dbl> <bch:byt> <dbl> <int> <dbl> <bch:tm> <list> <list> <list> <list>
#R> 1 R 147.9ms 151ms 6.57 78.2KB 0 4 0 609ms <NULL> <Rprofmem[,3] [2 × 3]> <bch:tm [4]> <tibble [4 × 3]>
#R> 2 OP's version 56.9ms 57.1ms 17.4 78.2KB 0 9 0 516ms <NULL> <Rprofmem[,3] [2 × 3]> <bch:tm [9]> <tibble [9 × 3]>
#R> 3 BLAS 90.1ms 90.7ms 11.0 78.2KB 0 6 0 545ms <NULL> <Rprofmem[,3] [2 × 3]> <bch:tm [6]> <tibble [6 × 3]>
#R> 4 C++ vanilla 57.2ms 57.4ms 17.4 78.2KB 0 9 0 518ms <NULL> <Rprofmem[,3] [2 × 3]> <bch:tm [9]> <tibble [9 × 3]>
#R> 5 C++ 51ms 51.4ms 19.3 78.2KB 0 10 0 519ms <NULL> <Rprofmem[,3] [2 × 3]> <bch:tm [10]> <tibble [10 × 3]>
So there is a small improvement. The results may, though, depend very much on the BLAS version. The version I use is
sessionInfo()
#R> #...
#R> Matrix products: default
#R> BLAS: /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.7.1
#R> LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.7.1
#R> ...
The whole file I source with Rcpp::sourceCpp() is
#include <Rcpp.h>
#include <R_ext/BLAS.h>
using namespace Rcpp;
inline void mat_vec_mult_vanilla
(double const * __restrict__ m,
double const * __restrict__ v,
double * __restrict__ const res,
size_t const dn, size_t const dm) noexcept {
for(size_t j = 0; j < dm; ++j, ++v){
double * r = res;
for(size_t i = 0; i < dn; ++i, ++r, ++m)
*r += *m * *v;
}
}
inline void mat_vec_mult
(double const * __restrict__ const m,
double const * __restrict__ const v,
double * __restrict__ const res,
size_t const dn, size_t const dm) noexcept {
size_t j(0L);
double const * vj = v,
* mi = m;
constexpr size_t const ncl(8L);
{
double const * mvals[ncl];
size_t const end_j = dm - (dm % ncl),
inc = ncl * dn;
for(; j < end_j; j += ncl, vj += ncl, mi += inc){
double *r = res;
mvals[0] = mi;
for(size_t i = 1; i < ncl; ++i)
mvals[i] = mvals[i - 1L] + dn;
for(size_t i = 0; i < dn; ++i, ++r)
for(size_t ii = 0; ii < ncl; ++ii)
*r += *(vj + ii) * *mvals[ii]++;
}
}
mat_vec_mult_vanilla(mi, vj, res, dn, dm - j);
}
// [[Rcpp::export("mat_vec_mult", rng = false)]]
NumericVector mat_vec_mult_cpp(NumericMatrix m, NumericVector v){
size_t const dn = m.nrow(),
dm = m.ncol();
NumericVector res(dn);
mat_vec_mult(&m[0], &v[0], &res[0], dn, dm);
return res;
}
// [[Rcpp::export("mat_vec_mult_vanilla", rng = false)]]
NumericVector mat_vec_mult_vanilla_cpp(NumericMatrix m, NumericVector v){
size_t const dn = m.nrow(),
dm = m.ncol();
NumericVector res(dn);
mat_vec_mult_vanilla(&m[0], &v[0], &res[0], dn, dm);
return res;
}
// [[Rcpp::export(rng = false)]]
NumericVector my_mm(NumericMatrix m, NumericVector v){
int nRow = m.rows();
int nCol = m.cols();
NumericVector ans(nRow);
double v_j;
for(int j = 0; j < nCol; j++){
v_j = v[j];
for(int i = 0; i < nRow; i++){
ans[i] += m(i,j) * v_j;
}
}
return(ans);
}
// [[Rcpp::export(rng = false)]]
NumericVector blas_mm(NumericMatrix m, NumericVector v){
int nRow = m.rows();
int nCol = m.cols();
NumericVector ans(nRow);
char trans = 'N';
double one = 1.0, zero = 0.0;
int ione = 1;
F77_CALL(dgemv)(&trans, &nRow, &nCol, &one, m.begin(), &nRow, v.begin(),
&ione, &zero, ans.begin(), &ione);
return ans;
}
/*** R
set.seed(1)
dn <- 10001L
dm <- 10001L
m <- matrix(rnorm(dn * dm), dn, dm)
lv <- rnorm(dm)
all.equal(drop(m %*% lv), mat_vec_mult(m = m, v = lv))
all.equal(drop(m %*% lv), mat_vec_mult_vanilla(m = m, v = lv))
bench::mark(
R = m %*% lv,
`OP's version` = my_mm(m = m, v = lv),
`BLAS` = blas_mm(m = m, v = lv),
`C++ vanilla` = mat_vec_mult_vanilla(m = m, v = lv),
`C++` = mat_vec_mult(m = m, v = lv), check = FALSE)
*/