我有一段代码可以搜索矩阵boxes
的哪些行等于给定的向量x
。该代码使用apply
函数,我想知道是否可以对其进行更多优化?
x = floor(runif(4)*10)/10
boxes = as.matrix(do.call(expand.grid, lapply(1:4, function(x) {
seq(0, 1 - 1/10, length = 10)
})))
# can the following line be more optimised ? :
result <- which(sapply(1:nrow(boxes),function(i){all(boxes[i,] == x)}))
我自己没有设法摆脱apply
函数,但也许您会比我有更好的主意:)
答案 0 :(得分:5)
一个选项是which(colSums(t(boxes) == x) == ncol(boxes))
。
向量是按列回收的,因此我们需要先对boxes
进行转置,然后再将x
与==
进行比较。然后我们可以选择which
列(转置行)的总和为ncol(boxes)
,即所有TRUE
值。
这是该示例的基准(可能不具有代表性)
Irnv <- function() which(sapply(1:nrow(boxes),function(i){all(boxes[i,] == x)}))
ICT <- function() which(colSums(t(boxes) == x) == ncol(boxes))
RS <- function() which(rowSums(mapply(function(i, j) boxes[, i] == j, seq_len(ncol(boxes)), x)) == length(x))
RS2 <- function(){
boxes <- data.frame(boxes)
which(rowSums(mapply(`==`, boxes, x)) == length(x))
}
akrun <- function() which(rowSums((boxes == x[col(boxes)])) == ncol(boxes))
microbenchmark(Irnv(), ICT(), RS(), RS2(), akrun())
# Unit: microseconds
# expr min lq mean median uq max neval
# Irnv() 19218.470 20122.2645 24182.2337 21882.8815 24949.1385 66387.719 100
# ICT() 300.308 323.2830 466.0395 342.3595 430.1545 7878.978 100
# RS() 566.564 586.2565 742.4252 617.2315 688.2060 8420.927 100
# RS2() 698.257 772.3090 1017.0427 842.2570 988.9240 9015.799 100
# akrun() 442.667 453.9490 579.9102 473.6415 534.5645 6870.156 100
答案 1 :(得分:2)
which(sapply(1:nrow(boxes),function(i){all(boxes[i,] == x)}))
#[1] 5805
使用mapply
来回答的问题。
which(rowSums(mapply(function(i, j) boxes[, i] == j, seq_len(ncol(boxes)), x)) == length(x))
#[1] 5805
如果允许将boxes
用作数据帧,则可以简化(仅减少按键,请参阅ICT的基准)以上版本。
boxes <- data.frame(boxes)
which(rowSums(mapply(`==`, boxes, x)) == length(x))
#[1] 5805
我的系统上的基准测试,可以在新的R会话中获得各种答案
Irnv <- function() which(sapply(1:nrow(boxes),function(i){all(boxes[i,] == x)}))
ICT <- function() which(colSums(t(boxes) == x) == ncol(boxes))
RS <- function() which(rowSums(mapply(function(i, j) boxes[, i] == j, seq_len(ncol(boxes)), x)) == length(x))
RS2 <- function(){
boxes <- data.frame(boxes)
which(rowSums(mapply(`==`, boxes, x)) == length(x))
}
akrun <- function() which(rowSums((boxes == x[col(boxes)])) == ncol(boxes))
akrun2 <- function() which(rowSums(boxes == rep(x, each = nrow(boxes))) == ncol(boxes))
akrun3 <- function() which(rowSums(sweep(boxes, 2, x, `==`)) == ncol(boxes))
library(microbenchmark)
microbenchmark(Irnv(), ICT(), RS(), RS2(), akrun(), akrun2(), akrun3())
#Unit: microseconds
# expr min lq mean median uq max neval
#Irnv() 16335.205 16720.8905 18545.0979 17640.7665 18691.234 49036.793 100
#ICT() 195.068 215.4225 444.9047 233.8600 329.288 4635.817 100
#RS() 527.587 577.1160 1344.3033 639.7180 1373.426 36581.216 100
#RS2() 648.996 737.6870 1810.3805 847.9865 1580.952 35263.632 100
#akrun() 384.498 402.1985 761.0542 421.5025 1176.129 4102.214 100
#akrun2() 840.324 853.9825 1415.9330 883.3730 1017.014 34662.084 100
#akrun3() 399.645 459.7685 1186.7605 488.3345 1215.601 38098.927 100
数据
set.seed(3251)
x = floor(runif(4)*10)/10
boxes = as.matrix(do.call(expand.grid, lapply(1:4, function(x) {
seq(0, 1 - 1/10, length = 10)
})))
答案 2 :(得分:2)
我们还可以在复制的'x'上使用rowSums
以使长度相同
which(rowSums((boxes == x[col(boxes)])) == ncol(boxes))
或使用rep
which(rowSums(boxes == rep(x, each = nrow(boxes))) == ncol(boxes))
或者使用sweep
和rowSums
which(rowSums(sweep(boxes, 2, x, `==`)) == ncol(boxes))