建议在多个值上优化简单的Scala foldLeft?

时间:2012-02-02 17:00:23

标签: scala functional-programming fold optimization

我正在重新实现一些代码(一种简单的贝叶斯推理算法,但这并不重要)从Java到Scala。我希望以尽可能最高效的方式实现它,同时通过尽可能避免可变性来保持代码的清洁和功能。

以下是Java代码的片段:

    // initialize
    double lP  = Math.log(prior);
    double lPC = Math.log(1-prior);

    // accumulate probabilities from each annotation object into lP and lPC
    for (Annotation annotation : annotations) {
        float prob = annotation.getProbability();
        if (isValidProbability(prob)) {
            lP  += logProb(prob);
            lPC += logProb(1 - prob);
        }
    } 

非常简单,对吧?所以我决定在第一次尝试时使用Scala foldLeft和map方法。由于我有两个正在积累的值,因此累加器是一个元组:

    val initial  = (math.log(prior), math.log(1-prior))
    val probs    = annotations map (_.getProbability)
    val (lP,lPC) = probs.foldLeft(initial) ((r,p) => {
      if(isValidProbability(p)) (r._1 + logProb(p), r._2 + logProb(1-p)) else r
    })

不幸的是,这段代码的执行速度比Java快5倍(使用简单且不精确的度量标准;只需在循环中调用代码10000次)。一个缺点很明显;我们遍历列表两次,一次是在map中调用,另一次是在foldLeft中。所以这是一个遍历列表的版本。

    val (lP,lPC) = annotations.foldLeft(initial) ((r,annotation) => {
      val  p = annotation.getProbability
      if(isValidProbability(p)) (r._1 + logProb(p), r._2 + logProb(1-p)) else r
    })

这样更好!它的执行速度比Java代码差3倍。我的下一个预感是,在折叠的每个步骤中创建所有新元组可能需要花费一些成本。所以我决定尝试两次遍历列表的版本,但不创建元组。

    val lP = annotations.foldLeft(math.log(prior)) ((r,annotation) => {
       val  p = annotation.getProbability
       if(isValidProbability(p)) r + logProb(p) else r
    })
    val lPC = annotations.foldLeft(math.log(1-prior)) ((r,annotation) => {
      val  p = annotation.getProbability
      if(isValidProbability(p)) r + logProb(1-p) else r
    })

这与前一版本大致相同(比Java版本慢3倍)。并不奇怪,但我很有希望。

所以我的问题是,是否有更快的方法在Scala中实现这个Java代码段,同时保持Scala代码干净,避免不必要的可变性并遵循Scala惯用法?我确实希望最终在并发环境中使用此代码,因此保持不变性的价值可能会超过单个线程中较慢的性能。

5 个答案:

答案 0 :(得分:4)

首先,您的一些惩罚可能来自您正在使用的收集类型。但是大部分可能是通过运行循环两次实际上无法避免的对象创建,因为必须将数字装箱。

相反,您可以创建一个可变类来为您累积值:

class LogOdds(var lp: Double = 0, var lpc: Double = 0) {
  def *=(p: Double) = {
    if (isValidProbability(p)) {
      lp += logProb(p)
      lpc += logProb(1-p)
    }
    this  // Pass self on so we can fold over the operation
  }
  def toTuple = (lp, lpc)
}

现在虽然你可以不安全地使用它,但你不必这样做。事实上,你可以折叠它。

annotations.foldLeft(new LogOdds()) { (r,ann) => r *= ann.getProbability } toTuple

如果你使用这种模式,所有可变的不安全隐藏在折叠内;它永远不会逃脱。

现在,你不能做一个平行折叠,但你可以做一个聚合,这就像一个折叠,有一个额外的操作来组合碎片。所以你添加方法

def **(lo: LogOdds) = new LogOdds(lp + lo.lp, lpc + lo.lpc)

LogOdds然后

annotations.aggregate(new LogOdds())(
  (r,ann) => r *= ann.getProbability,
  (l,r) => l**r
).toTuple

你会很高兴。

(可以随意使用非数学符号,但由于你基本上是概率乘法,乘法符号似乎更容易给出一个直观的想法,而不是包含概率或某些东西。)

答案 1 :(得分:3)

您可以实现一个尾递归方法,该方法将由编译器转换为while循环,因此应该与Java版本一样快。或者,您可以使用一个循环 - 如果它只是在方法中使用局部变量(例如,请参阅Scala集合源代码中的大量使用),则没有法律可以反对它。

def calc(lst: List[Annotation], lP: Double = 0, lPC: Double = 0): (Double, Double) = {
  if (lst.isEmpty) (lP, lPC)
  else {
    val prob = lst.head.getProbability
    if (isValidProbability(prob)) 
      calc(lst.tail, lP + logProb(prob), lPC + logProb(1 - prob))
    else 
      calc(lst.tail, lP, lPC)
  }
}

折叠的优点是它可以并行化,这可能导致它比多核机器上的Java版本更快(参见其他答案)。

答案 2 :(得分:2)

作为一种旁注:您可以使用view避免以惯用方式遍历列表两次:

val probs = annotations.view.map(_.getProbability).filter(isValidProbability)

val (lP, lPC) = ((logProb(prior), logProb(1 - prior)) /: probs) {
   case ((pa, ca), p) => (pa + logProb(p), ca + logProb(1 - p))
}

这可能不会让你比第三版更好,但对我来说感觉更优雅。

答案 3 :(得分:2)

首先,让我们解决性能问题:除了使用while循环之外,没有办法像Java那样快速地实现它。基本上,JVM无法优化Scala循环到优化Java循环的程度。其原因甚至是JVM民众关注的问题,因为它也妨碍了他们并行的图书馆工作。

现在,回到Scala效果,您还可以使用.view来避免在map步骤中创建新集合,但我认为map步骤将导致性能下降。问题是,您正在将集合转换为Double上的参数化集合,必须将其装箱并取消装箱。

然而,有一种可能的优化方法:使其平行。如果您在.par上致电annotations以使其成为并行收藏品,则可以使用fold

val parAnnot = annotations.par
val lP = parAnnot.map(_.getProbability).fold(math.log(prior)) ((r,p) => {
   if(isValidProbability(p)) r + logProb(p) else r
})
val lPC = parAnnot.map(_.getProbability).fold(math.log(1-prior)) ((r,p) => {
  if(isValidProbability(p)) r + logProb(1-p) else r
})

要避免单独的map步骤,请按照Rex的建议使用aggregate代替fold

对于奖励积分,您可以使用Future使两个计算并行运行。我怀疑你可以通过将元组带回来并一次性运行它来获得更好的性能。你必须对这些东西进行基准测试才能看出哪些更好。

对于并行收集,它可能会因为有效注释而首先filter得到回报。或者,也许是collect

val parAnnot = annottions.par.view map (_.getProbability) filter (isValidProbability(_)) force;

val parAnnot = annotations.par collect { case annot if isValidProbability(annot.getProbability) => annot.getProbability }

无论如何,基准。

答案 4 :(得分:1)

目前无法在没有装箱的情况下与scala集合库进行交互。那么Java中的原始double将在fold操作中不断被装箱和取消装箱,即使你没有将它们包裹在Tuple2中(是< / em>专业 - 但当然你已经付出了每次创建新对象的性能开销。)