具有动态列输入的按行划分的列乘积-矢量化操作

时间:2019-04-11 19:57:28

标签: r

我想将以下代码向量化,以进行更有效的处理。我需要按行取列的乘积(即rowProds),但是我希望乘积的列数必须是另一个输入的函数。

如果可能的话,我希望使用Base R完成此操作,但是我愿意接受并接受任何建议。

这可以轻松地使用循环或带有udf的家庭来完成,但是这些速度不足以满足我的需求。

# Generate some data

mat <- data.frame(X = 1:5)
for (i in 1:5) {
  set.seed(i)
  mat[1 + i] <- runif(5)
}

# Via a for loop

for (i in 1:nrow(mat)) {  
  mat$calc[i] <- prod(mat[match(mat$X[i], mat$X), 2:(i + 1)])
}
mat

# Via a function with mapply

rowprodfun <- function(X) {  
  myprod <- prod(mat[match(X, mat$X), 2:(X + 1)])
  return(myprod)
}

mat$calc <- mapply(rowprodfun, mat$X)
mat

mat$calc
# [1] 0.265508663 0.261370165 0.126427355 0.013874517 0.009758232

以上两种方法均导致相同的“ calc”列。我只需要一种更有效的方法来生成此列。

2 个答案:

答案 0 :(得分:1)

一种选择是将上三角元素转换为NA,然后使用rowProds中的matrixStats

library(matrixStats)
rowProds(as.matrix(mat[-1] * NA^upper.tri(mat[-1])), na.rm = TRUE)
#[1] 0.265508663 0.261370165 0.126427355 0.013874517 0.009758232

答案 1 :(得分:0)

使用@akrun建议的upper.tri很有帮助。最后一步是将data.frame转换为矩阵as.matrix ,然后进行逐元素乘法。

rowProds(as.matrix(mat[-1]) * NA ^ upper.tri(mat[-1]), na.rm = T)导致计算效率最高。

apply(as.matrix(mat[-1]) * NA ^ upper.tri(mat[-1]), 1, prod, na.rm = T)的效率几乎要比基本R高。

library(microbenchmark)
library(matrixStats)
library(ggplot2)

Y <- microbenchmark(
  for.loop = for (i in 1:nrow(mat)) {prod(mat[match(mat$X[i], mat$X), 2:(i + 1)])},
  mapply.fun = mapply(rowprodfun, mat$X),
  rowProds = rowProds(as.matrix(mat[-1] * NA ^ upper.tri(mat[-1])), na.rm = T),
  rowProds.matrix = rowProds(as.matrix(mat[-1]) * NA ^ upper.tri(mat[-1]), na.rm = T),
  apply = apply(mat[-1] * NA ^ upper.tri(mat[-1]), 1, prod, na.rm = T),
  apply.matrix = apply(as.matrix(mat[-1]) * NA ^ upper.tri(mat[-1]), 1, prod, na.rm = T)
)

> Y
Unit: microseconds
            expr      min        lq      mean    median        uq       max neval
        for.loop 4094.869 4305.5590 5682.2124 4479.8125 5193.8190 50361.025   100
      mapply.fun  542.962  577.6995 1036.9821  599.2220  658.1245 32426.296   100
        rowProds  518.419  553.9120  654.2657  597.5225  637.1690  2434.267   100
 rowProds.matrix   99.304  116.1065  144.9313  128.0010  153.8650   516.909   100
           apply  547.493  580.1540  686.2317  628.2955  703.0565  1215.812   100
    apply.matrix  117.051  136.6845  158.3808  144.9920  156.5075   339.068   100

benchmarkautoplot