这个R函数可以被矢量化吗?

时间:2012-10-05 14:53:26

标签: r vectorization conditional-statements

bucketIndex <- function(v, N){
  o <- rep(0, length(v))

  curSum <- 0
  index  <- 1

  for(i in seq(length(v))){
    o[i] <- index

    curSum <- curSum + v[i]
    if(curSum > N){
      curSum <- 0
      index <- index + 1
    }
  }

  o
}

> bucketIndex(c(1, 1, 2, 1, 5, 1), 3)
[1] 1 1 1 2 2 3

我想知道这个功能是否从根本上是不可矢量化的。如果是,是否有一些软件包来处理这个“类”的函数,或者是唯一的替代方法(如果我想要速度)将其写为c扩展名?

3 个答案:

答案 0 :(得分:2)

这是一次尝试(尚未到达bucketIndex!):

  • 你的

    curSum <- curSum + v[i]
    if(curSum > N){
      curSum <- 0
      index <- index + 1
    }  
    

    几乎是%/%的整数除cumsum (v)

  • 但不完全是,即使v [i]是&gt;,您的索引也只会计数1。几次N,你从1开始。我们几乎可以通过转换为一个因子并返回整数来处理这个问题。

  • 但是,我想知道(从函数的名称)这种行为是否真正意图:

    > bucketIndex (c(1, 1, 2, 1, 2, 1, 1, 2, 1, 5, 1), 3)
    [1] 1 1 1 2 2 2 3 3 3 4 5
    > bucketIndex (c(1, 1, 1, 2, 2, 1, 1, 2, 1, 5, 1), 3)
    [1] 1 1 1 1 2 2 2 3 3 3 4
    

    即。只需在v中连续两次输入,就可以在结果中产生不同的最大值。

  • 另一点是,在导致总和为&gt;的元素之后,只计算 N。这意味着结果应该在开头有一个额外的1,最后一个元素应该被删除。

  • 您将curSum重置为0,无论它在N上拍摄多少。因此,对于cumsum (v) > N的所有元素,您需要减去此值,然后查找下一个cumsum (v) > N,依此类推。这减少了与for循环相关的循环迭代次数,但这是否会为您提供次级改进取决于vN的条目(或max (index)上的条目1}}:length (v)比例)。如果这是你的例子中的50%,我认为你不会获得实质性收益。除非他们之间至少有一个重要的顺序,否则我会去inline::cfunction

答案 1 :(得分:0)

我要在这里走出去,说答案是“不”。从本质上讲,你根据当前总和的结果改变你总结的东西。这意味着未来的计算取决于中间计算的结果,矢量化操作无法做到。

答案 2 :(得分:0)

我认为这不是完全可矢量化的,但是@cbeleites通过一次处理整个块(桶)来减少循环中的迭代次数。每次迭代都会查找累积和超过N的位置,将索引分配给该范围,将累积和减少超出N的任何值,并重复直到向量耗尽。其余的是簿记(值的初始化和值的增量)。

bucketIndex2 <- function(v, N) {
    index <- 1
    cs <- cumsum(v)
    bk.old <- 0
    o <- rep(0, length(v))

    repeat {
        bk <- suppressWarnings(min(which(cs > N)))
        o[(bk.old+1):min(bk,length(v))] <- index
        if (bk >= length(v)) break
        cs <- cs - cs[bk]
        index <- index + 1
        bk.old <- bk
    }

    o
}

这与您的各种随机输入功能相匹配:

for (i in 1:200) {
  v <- sample(sample(20,1), sample(50,1)+20, replace=TRUE)
  N <- sample(10,1)
  bi <- bucketIndex(v, N)
  bi2 <- bucketIndex2(v, N)
  if (any(bi != bi2)) {
    print("MISMATCH:")
    dump("v","")
    dump("N","")
  }
}