矩阵的哪几行等于某个向量

时间:2019-02-11 13:50:39

标签: r apply

我有一段代码可以搜索矩阵boxes的哪些行等于给定的向量x。该代码使用apply函数,我想知道是否可以对其进行更多优化?

x = floor(runif(4)*10)/10
boxes = as.matrix(do.call(expand.grid, lapply(1:4, function(x) {
  seq(0, 1 - 1/10, length = 10)
})))

# can the following line be more optimised ? :
result <- which(sapply(1:nrow(boxes),function(i){all(boxes[i,] == x)}))

我自己没有设法摆脱apply函数,但也许您会比我有更好的主意:)

3 个答案:

答案 0 :(得分:5)

一个选项是which(colSums(t(boxes) == x) == ncol(boxes))

向量是按列回收的,因此我们需要先对boxes进行转置,然后再将x==进行比较。然后我们可以选择which列(转置行)的总和为ncol(boxes),即所有TRUE值。

这是该示例的基准(可能不具有代表性)

Irnv <- function() which(sapply(1:nrow(boxes),function(i){all(boxes[i,] == x)}))
ICT <- function() which(colSums(t(boxes) == x) == ncol(boxes))
RS <- function() which(rowSums(mapply(function(i, j) boxes[, i] == j, seq_len(ncol(boxes)), x)) == length(x))
RS2 <- function(){ 
  boxes <- data.frame(boxes)
  which(rowSums(mapply(`==`, boxes, x)) == length(x))
}
akrun <- function() which(rowSums((boxes == x[col(boxes)])) == ncol(boxes))


microbenchmark(Irnv(), ICT(), RS(), RS2(), akrun())
# Unit: microseconds
#     expr       min         lq       mean     median         uq       max neval
#   Irnv() 19218.470 20122.2645 24182.2337 21882.8815 24949.1385 66387.719   100
#    ICT()   300.308   323.2830   466.0395   342.3595   430.1545  7878.978   100
#     RS()   566.564   586.2565   742.4252   617.2315   688.2060  8420.927   100
#    RS2()   698.257   772.3090  1017.0427   842.2570   988.9240  9015.799   100
#  akrun()   442.667   453.9490   579.9102   473.6415   534.5645  6870.156   100

答案 1 :(得分:2)

which(sapply(1:nrow(boxes),function(i){all(boxes[i,] == x)}))
#[1] 5805

使用mapply来回答的问题。

which(rowSums(mapply(function(i, j) boxes[, i] == j, seq_len(ncol(boxes)), x)) == length(x))
#[1] 5805

如果允许将boxes用作数据帧,则可以简化(仅减少按键,请参阅ICT的基准)以上版本。

boxes <- data.frame(boxes)
which(rowSums(mapply(`==`, boxes, x)) == length(x))
#[1] 5805

我的系统上的基准测试,可以在新的R会话中获得各种答案

Irnv <- function() which(sapply(1:nrow(boxes),function(i){all(boxes[i,] == x)}))
ICT <- function() which(colSums(t(boxes) == x) == ncol(boxes))
RS <- function() which(rowSums(mapply(function(i, j) boxes[, i] == j, seq_len(ncol(boxes)), x)) == length(x))
RS2 <- function(){ 
   boxes <- data.frame(boxes)
   which(rowSums(mapply(`==`, boxes, x)) == length(x))
 }
akrun <- function() which(rowSums((boxes == x[col(boxes)])) == ncol(boxes))
akrun2 <- function() which(rowSums(boxes == rep(x, each = nrow(boxes))) == ncol(boxes))
akrun3 <- function() which(rowSums(sweep(boxes, 2, x, `==`)) == ncol(boxes))

library(microbenchmark)
microbenchmark(Irnv(), ICT(), RS(), RS2(), akrun(), akrun2(), akrun3())


#Unit: microseconds
# expr          min         lq       mean     median        uq       max neval
#Irnv()   16335.205 16720.8905 18545.0979 17640.7665 18691.234 49036.793   100
#ICT()      195.068   215.4225   444.9047   233.8600   329.288  4635.817   100
#RS()       527.587   577.1160  1344.3033   639.7180  1373.426 36581.216   100
#RS2()      648.996   737.6870  1810.3805   847.9865  1580.952 35263.632   100
#akrun()    384.498   402.1985   761.0542   421.5025  1176.129  4102.214   100
#akrun2()   840.324   853.9825  1415.9330   883.3730  1017.014 34662.084   100
#akrun3()   399.645   459.7685  1186.7605   488.3345  1215.601 38098.927   100

数据

set.seed(3251)
x = floor(runif(4)*10)/10
boxes = as.matrix(do.call(expand.grid, lapply(1:4, function(x) {
              seq(0, 1 - 1/10, length = 10)
})))

答案 2 :(得分:2)

我们还可以在复制的'x'上使用rowSums以使长度相同

which(rowSums((boxes == x[col(boxes)])) == ncol(boxes))

或使用rep

which(rowSums(boxes == rep(x, each = nrow(boxes))) == ncol(boxes))

或者使用sweeprowSums

which(rowSums(sweep(boxes, 2, x, `==`)) == ncol(boxes))