R中的逐行矩阵乘法

时间:2018-02-20 07:10:37

标签: r

我有一个矩阵,其维度为1亿条记录和100列。

现在我想通过rowwise将该矩阵相乘。

我的矩阵乘法示例代码是

df<-as.matrix(mtcars)
result<-apply(df,1,prod)

在我的情况下,上面的语法非常慢。

我在 Rfast 包中尝试了 rowprods 功能。

result<-rowprods(mtcars)

但上述功能给了我空间问题。

注意:我的系统中有8 GB RAM。

4 个答案:

答案 0 :(得分:4)

如果您的矩阵太大而无法容纳在内存中,您可以使用包 bigstatsr (免责声明:我是作者)来使用存储在磁盘上的数据(而不是RAM) )。使用函数<android.support.design.widget.TabLayout android:id="@+id/tabs" android:layout_width="match_parent" android:layout_height="wrap_content" app:tabMaxWidth="0dp" app:tabGravity="fill" app:tabMode="fixed" /> 可以在数据块上应用标准R函数(并将它们组合在一起)。

big_apply

答案 1 :(得分:2)

使用data.table尝试包Reduce。这可能会避免1e10长度向量的内部副本。

library(data.table)
df <- data.table(df, keep.rownames=TRUE)
df[, rowprods:= Reduce("*", .SD), .SDcols = -1]
df[, .(rn, rowprods)]
#                     rn   rowprods
# 1:           Mazda RX4          0
# 2:       Mazda RX4 Wag          0
# 3:          Datsun 710  609055152
# 4:      Hornet 4 Drive          0
# 5:   Hornet Sportabout          0
# 6:             Valiant          0
# 7:          Duster 360          0
# 8:           Merc 240D          0
# 9:            Merc 230          0
#10:            Merc 280          0
#11:           Merc 280C          0
#12:          Merc 450SE          0
#13:          Merc 450SL          0
#14:         Merc 450SLC          0
#15:  Cadillac Fleetwood          0
#16: Lincoln Continental          0
#17:   Chrysler Imperial          0
#18:            Fiat 128  470578906
#19:         Honda Civic  564655046
#20:      Toyota Corolla  386281789
#21:       Toyota Corona          0
#22:    Dodge Challenger          0
#23:         AMC Javelin          0
#24:          Camaro Z28          0
#25:    Pontiac Firebird          0
#26:           Fiat X1-9  339825992
#27:       Porsche 914-2          0
#28:        Lotus Europa 1259677924
#29:      Ford Pantera L          0
#30:        Ferrari Dino          0
#31:       Maserati Bora          0
#32:          Volvo 142E 1919442833
#                     rn    rowsums

但是,如果要处理此大小的数据,8 GB RAM(减去操作系统和其他软件需要的内容)并不多。 R有时需要制作内部副本才能使用您的数据。

答案 2 :(得分:2)

一些参考时间

library(matrixStats)
library(inline)
library(data.table)
#devtools::install_github("privefl/bigstatsr")
library(bigstatsr)
library(RcppArmadillo)
library(microbenchmark)
set.seed(20L)
N <- 1e6
dat <- matrix(rnorm(N*100),ncol=100)

fbm <- FBM(N, 100)
big_apply(fbm, a.FUN = function(X, ind) {
    print(min(ind))
    X[, ind] <- rnorm(nrow(X) * length(ind))
    NULL
}, a.combine = 'c')   

bigstatsrMtd <- function() {
    prods <- big_apply(fbm, a.FUN = function(X, ind) {
        print(min(ind))
        matrixStats::rowProds(X[ind, ])
    }, a.combine = 'c', ind = rows_along(fbm),
        block.size = 100e3, ncores = nb_cores())  
}

df <- data.table(as.data.frame(dat), keep.rownames=TRUE)
data.tableMtd <- function() {
    df[, rowprods:= Reduce("*", .SD), .SDcols = -1]
    df[, .(rn, rowprods)]    
}

code <- '
  arma::mat prodDat = Rcpp::as<arma::mat>(dat);
  int m = prodDat.n_rows;
  int n = prodDat.n_cols;
  arma::vec res(m);
  for (int row=0; row < m; row++) {
    res(row) = 1.0;
    for (int col=0; col < n; col++) {
      res(row) *= prodDat(row, col);
    }
  }
  return Rcpp::wrap(res);
'
rcppProd <- cxxfunction(signature(dat="numeric"), code, plugin="RcppArmadillo")

rcppMtd <- function() {
    rcppData <- rcppProd(dat)                # generated by C++ code
}

baseMtd <- function() {
    apply(dat, 1, prod)   
}

microbenchmark(bigstatsrMtd(),
    data.tableMtd(),
    rcppMtd(),
    baseMtd(),
    times=5L
)

注意:在cxxfunction中编译函数似乎需要一些时间

以下是时间安排结果:

# Unit: milliseconds
#            expr       min        lq      mean    median        uq       max
#  bigstatsrMtd() 4519.1861 4993.0879 5296.7000 5126.2282 5504.3981 6340.5995
# data.tableMtd()  443.1946  444.9686  690.3703  493.2399  513.4787 1556.9695
#       rcppMtd()  787.9488  799.1575  828.3647  809.0645  871.0347  874.6178
#       baseMtd() 5658.1424 6208.5123 6232.0040 6331.7431 6458.6806 6502.9417

答案 3 :(得分:1)

Rfast命令“行探针” 接受矩阵,而不是data.frame。其次,任何行或colprods命令都将出现数值溢出错误。所以最好用 Rfast :: colprods(x,method =“ expsumlog”)