我正在重新实现一些代码(一种简单的贝叶斯推理算法,但这并不重要)从Java到Scala。我希望以尽可能最高效的方式实现它,同时通过尽可能避免可变性来保持代码的清洁和功能。
以下是Java代码的片段:
// initialize
double lP = Math.log(prior);
double lPC = Math.log(1-prior);
// accumulate probabilities from each annotation object into lP and lPC
for (Annotation annotation : annotations) {
float prob = annotation.getProbability();
if (isValidProbability(prob)) {
lP += logProb(prob);
lPC += logProb(1 - prob);
}
}
非常简单,对吧?所以我决定在第一次尝试时使用Scala foldLeft和map方法。由于我有两个正在积累的值,因此累加器是一个元组:
val initial = (math.log(prior), math.log(1-prior))
val probs = annotations map (_.getProbability)
val (lP,lPC) = probs.foldLeft(initial) ((r,p) => {
if(isValidProbability(p)) (r._1 + logProb(p), r._2 + logProb(1-p)) else r
})
不幸的是,这段代码的执行速度比Java快5倍(使用简单且不精确的度量标准;只需在循环中调用代码10000次)。一个缺点很明显;我们遍历列表两次,一次是在map中调用,另一次是在foldLeft中。所以这是一个遍历列表的版本。
val (lP,lPC) = annotations.foldLeft(initial) ((r,annotation) => {
val p = annotation.getProbability
if(isValidProbability(p)) (r._1 + logProb(p), r._2 + logProb(1-p)) else r
})
这样更好!它的执行速度比Java代码差3倍。我的下一个预感是,在折叠的每个步骤中创建所有新元组可能需要花费一些成本。所以我决定尝试两次遍历列表的版本,但不创建元组。
val lP = annotations.foldLeft(math.log(prior)) ((r,annotation) => {
val p = annotation.getProbability
if(isValidProbability(p)) r + logProb(p) else r
})
val lPC = annotations.foldLeft(math.log(1-prior)) ((r,annotation) => {
val p = annotation.getProbability
if(isValidProbability(p)) r + logProb(1-p) else r
})
这与前一版本大致相同(比Java版本慢3倍)。并不奇怪,但我很有希望。
所以我的问题是,是否有更快的方法在Scala中实现这个Java代码段,同时保持Scala代码干净,避免不必要的可变性并遵循Scala惯用法?我确实希望最终在并发环境中使用此代码,因此保持不变性的价值可能会超过单个线程中较慢的性能。
答案 0 :(得分:4)
首先,您的一些惩罚可能来自您正在使用的收集类型。但是大部分可能是通过运行循环两次实际上无法避免的对象创建,因为必须将数字装箱。
相反,您可以创建一个可变类来为您累积值:
class LogOdds(var lp: Double = 0, var lpc: Double = 0) {
def *=(p: Double) = {
if (isValidProbability(p)) {
lp += logProb(p)
lpc += logProb(1-p)
}
this // Pass self on so we can fold over the operation
}
def toTuple = (lp, lpc)
}
现在虽然你可以不安全地使用它,但你不必这样做。事实上,你可以折叠它。
annotations.foldLeft(new LogOdds()) { (r,ann) => r *= ann.getProbability } toTuple
如果你使用这种模式,所有可变的不安全隐藏在折叠内;它永远不会逃脱。
现在,你不能做一个平行折叠,但你可以做一个聚合,这就像一个折叠,有一个额外的操作来组合碎片。所以你添加方法
def **(lo: LogOdds) = new LogOdds(lp + lo.lp, lpc + lo.lpc)
到LogOdds
然后
annotations.aggregate(new LogOdds())(
(r,ann) => r *= ann.getProbability,
(l,r) => l**r
).toTuple
你会很高兴。
(可以随意使用非数学符号,但由于你基本上是概率乘法,乘法符号似乎更容易给出一个直观的想法,而不是包含概率或某些东西。)
答案 1 :(得分:3)
您可以实现一个尾递归方法,该方法将由编译器转换为while循环,因此应该与Java版本一样快。或者,您可以使用一个循环 - 如果它只是在方法中使用局部变量(例如,请参阅Scala集合源代码中的大量使用),则没有法律可以反对它。
def calc(lst: List[Annotation], lP: Double = 0, lPC: Double = 0): (Double, Double) = {
if (lst.isEmpty) (lP, lPC)
else {
val prob = lst.head.getProbability
if (isValidProbability(prob))
calc(lst.tail, lP + logProb(prob), lPC + logProb(1 - prob))
else
calc(lst.tail, lP, lPC)
}
}
折叠的优点是它可以并行化,这可能导致它比多核机器上的Java版本更快(参见其他答案)。
答案 2 :(得分:2)
作为一种旁注:您可以使用view
避免以惯用方式遍历列表两次:
val probs = annotations.view.map(_.getProbability).filter(isValidProbability)
val (lP, lPC) = ((logProb(prior), logProb(1 - prior)) /: probs) {
case ((pa, ca), p) => (pa + logProb(p), ca + logProb(1 - p))
}
这可能不会让你比第三版更好,但对我来说感觉更优雅。
答案 3 :(得分:2)
首先,让我们解决性能问题:除了使用while循环之外,没有办法像Java那样快速地实现它。基本上,JVM无法优化Scala循环到优化Java循环的程度。其原因甚至是JVM民众关注的问题,因为它也妨碍了他们并行的图书馆工作。
现在,回到Scala效果,您还可以使用.view
来避免在map
步骤中创建新集合,但我认为map
步骤将导致性能下降。问题是,您正在将集合转换为Double
上的参数化集合,必须将其装箱并取消装箱。
然而,有一种可能的优化方法:使其平行。如果您在.par
上致电annotations
以使其成为并行收藏品,则可以使用fold
:
val parAnnot = annotations.par
val lP = parAnnot.map(_.getProbability).fold(math.log(prior)) ((r,p) => {
if(isValidProbability(p)) r + logProb(p) else r
})
val lPC = parAnnot.map(_.getProbability).fold(math.log(1-prior)) ((r,p) => {
if(isValidProbability(p)) r + logProb(1-p) else r
})
要避免单独的map
步骤,请按照Rex的建议使用aggregate
代替fold
。
对于奖励积分,您可以使用Future
使两个计算并行运行。我怀疑你可以通过将元组带回来并一次性运行它来获得更好的性能。你必须对这些东西进行基准测试才能看出哪些更好。
对于并行收集,它可能会因为有效注释而首先filter
得到回报。或者,也许是collect
。
val parAnnot = annottions.par.view map (_.getProbability) filter (isValidProbability(_)) force;
或
val parAnnot = annotations.par collect { case annot if isValidProbability(annot.getProbability) => annot.getProbability }
无论如何,基准。
答案 4 :(得分:1)
目前无法在没有装箱的情况下与scala集合库进行交互。那么Java中的原始double
将在fold
操作中不断被装箱和取消装箱,即使你没有将它们包裹在Tuple2
中(是< / em>专业 - 但当然你已经付出了每次创建新对象的性能开销。)