我有这个R函数来生成一个矩阵,其中包含0到n之间k个数的所有组合,其总和等于n。这是我的程序的瓶颈之一,因为它变得非常慢,即使数量很少(因为它正在计算功率集)
这是代码
sum.comb <-
function(n,k) {
ls1 <- list() # generate empty list
for(i in 1:k) { # how could this be done with apply?
ls1[[i]] <- 0:n # fill with 0:n
}
allc <- as.matrix(expand.grid(ls1)) # generate all combinations, already using the built in function
colnames(allc) <- NULL
index <- (rowSums(allc) == n) # make index with only the ones that sum to n
allc[index, ,drop=F] # matrix with only the ones that sum to n
}
答案 0 :(得分:4)
除非您回答关于n
和k
的典型值的问题,否则很难判断它是否有用(请执行。)这是使用递归的版本,似乎更快比josilber使用他的基准测试:
sum.comb3 <- function(n, k) {
stopifnot(k > 0L)
REC <- function(n, k) {
if (k == 1L) list(n) else
unlist(lapply(0:n, function(i)Map(c, i, REC(n - i, k - 1L))),
recursive = FALSE)
}
matrix(unlist(REC(n, k)), ncol = k, byrow = TRUE)
}
microbenchmark(sum.comb(3, 10), sum.comb2(3, 10), sum.comb3(3, 10))
# Unit: milliseconds
# expr min lq median uq max neval
# sum.comb2(3, 10) 39.55612 40.60798 41.91954 44.26756 70.44944 100
# sum.comb3(3, 10) 25.86008 27.74415 28.37080 29.65567 34.18620 100
答案 1 :(得分:3)
这是一种不同的方法,它在每次迭代中逐步将集合从大小1扩展到k,从而修剪总和超过n的组合。这会导致你有一个相对于n的大k的加速,因为你不需要计算任何接近功率集大小的东西。
sum.comb2 <- function(n, k) {
combos <- 0:n
sums <- 0:n
for (width in 2:k) {
combos <- apply(expand.grid(combos, 0:n), 1, paste, collapse=" ")
sums <- apply(expand.grid(sums, 0:n), 1, sum)
if (width == k) {
return(combos[sums == n])
} else {
combos <- combos[sums <= n]
sums <- sums[sums <= n]
}
}
}
# Simple test
sum.comb2(3, 2)
# [1] "3 0" "2 1" "1 2" "0 3"
以下是小n和大k的加速的示例:
library(microbenchmark)
microbenchmark(sum.comb2(1, 100))
# Unit: milliseconds
# expr min lq median uq max neval
# sum.comb2(1, 100) 149.0392 158.716 162.1919 174.0482 236.2095 100
这种方法在不到一秒的时间内运行,而功率集的方法当然永远不会超过expand.grid
的调用,因为你的结果矩阵最终会有2 ^ 100行
即使在一个不太极端的情况下,n = 3和k = 10,我们看到与原始帖子中的函数相比增加了20倍:
microbenchmark(sum.comb(3, 10), sum.comb2(3, 10))
# Unit: milliseconds
# expr min lq median uq max neval
# sum.comb(3, 10) 404.00895 439.94472 446.67452 461.24909 574.80426 100
# sum.comb2(3, 10) 23.27445 24.53771 25.60409 26.97439 65.59576 100
答案 2 :(得分:2)
请参阅partitions
软件包({1}}和compositions()
,它们作为整个矩阵生成器和迭代操作都会更快。然后,如果仍然不够快,请参阅有关组合和分区生成算法(无环路,格雷码和并行)的各种出版物,如Daniel Page's research。
blockparts()
答案 3 :(得分:1)
以下可以用lapply完成。
ls1 <- list()
for(i in 1:k) {
ls1[[i]] <- 0:n
}
尝试替换这是,看看你是否加速。
ls1 = lapply(1:k,function(x) 0:n)
我将'ls'改为'ls1',因为ls()是一个R函数。
答案 4 :(得分:1)
如此简短:
comb = function(n, k) {
all = combn(0:n, k)
sums = colSums(all)
all[, sums == n]
}
然后像:
comb(5, 3)
根据您的要求生成矩阵:
[,1] [,2]
[1,] 0 0
[2,] 1 2
[3,] 4 3
感谢@josilber和原始海报,指出OP需要所有排列重复而不是组合。排列的类似方法如下:
perm = function(n, k) {
grid = matrix(rep(0:n, k), n + 1, k)
all = expand.grid(data.frame(grid))
sums = rowSums(all)
all[sums == n,]
}
然后像:
perm(5, 3)
根据您的要求生成矩阵:
X1 X2 X3
6 5 0 0
11 4 1 0
16 3 2 0
21 2 3 0
26 1 4 0
31 0 5 0
...