请考虑以下方法-已被验证符合正确的tail recursion
:
@tailrec
def getBoundaries(grps: Seq[(BigDecimal, Int)], groupSize: Int, curSum: Int = 0, curOffs: Seq[BigDecimal] = Seq.empty[BigDecimal]): Seq[BigDecimal] = {
if (grps.isEmpty) curOffs
else {
val (id, cnt) = grps.head
val newSum = curSum + cnt.toInt
if (newSum%50==0) { println(s"id=$id newsum=$newSum") }
if (newSum >= groupSize) {
getBoundaries(grps.tail, groupSize, 0, curOffs :+ id) // r1
} else {
getBoundaries(grps.tail, groupSize, newSum, curOffs) // r2
}
}
}
运行速度非常-大约每秒75个循环。当我点击stacktrace(Intellij
的一个不错的功能)时,几乎每次当前调用的行都是第二个尾递归调用r2
。这个事实使我对所谓的“ scala将递归调用分解为while
循环”感到怀疑。如果正在发生解包,那么为什么调用本身会花费这么多时间?
除了具有适当结构的尾部递归方法外,还有其他考虑来使递归例程的性能接近直接迭代吗?
答案 0 :(得分:2)
效果取决于Seq
的基础类型。
如果它是List
,则问题是将(:+
)附加到List
上,因为长列表会变得很慢,因为它必须扫描整个列表以找到结尾。
一种解决方案是每次将 放在列表(+:
之前,然后在reverse
末尾添加。这可以大大提高性能,因为添加到列表的开头非常快。
其他Seq
类型将具有不同的性能特征,但是您可以在递归调用之前将其转换为List
,以了解其性能。
这是示例代码
def getBoundaries(grps: Seq[(BigDecimal, Int)], groupSize: Int): Seq[BigDecimal] = {
@tailrec
def loop(grps: List[(BigDecimal, Int)], curSum: Int, curOffs: List[BigDecimal]): List[BigDecimal] =
if (grps.isEmpty) curOffs
else {
val (id, cnt) = grps.head
val newSum = curSum + cnt.toInt
if (newSum >= groupSize) {
loop(grps.tail, 0, id +: curOffs) // r1
} else {
loop(grps.tail, newSum, curOffs) // r2
}
}
loop(grps.toList, 0, Nil).reverse
}
使用发问者自己对问题的答案中提供的测试数据,此版本的性能比原始代码提高了10倍。
答案 1 :(得分:0)
问题不在递归中,而是在数组操作中。使用以下测试用例,它以每秒 200K递归
的速度运行 type Fgroups = Seq[(BigDecimal, Int)]
test("testGetBoundaries") {
val N = 200000
val grps: Fgroups = (N to 1 by -1).flatMap { x => Array.tabulate(x % 20){ x2 => (BigDecimal(x2 * 1e9), 1) }}
val sgrps = grps.sortWith { case (a, b) =>
a._1.longValue.compare(b._1.longValue) < 0
}
val bb = getBoundaries(sgrps, 100 )
println(bb.take(math.min(50,bb.length)).mkString(","))
assert(bb.length==1900)
}
我的生产数据样本的条目数量相似(Array
有233K行),但运行速度却慢了 3个数量级。我现在正在调查 tail 操作和其他罪魁祸首。
更新来自Alvin Alexander
的以下参考信息指出,对于不可变集合,tail
操作应该是v快速的-但长时间< em>可变的-包括数组的!
哇!我不知道在scala中使用mutable
集合会对性能产生影响!
更新,通过添加代码将Array
转换为(immutable
)Seq
,我发现生产数据样本的性能提高了3个数量级:
val grps = if (grpsIn.isInstanceOf[mutable.WrappedArray[_]] || grpsIn.isInstanceOf[Array[_]]) {
Seq(grpsIn: _*)
} else grpsIn
(现在 fast 〜200K / sec)的最终代码是:
type Fgroups = Seq[(BigDecimal, Int)]
val cntr = new java.util.concurrent.atomic.AtomicInteger
@tailrec
def getBoundaries(grpsIn: Fgroups, groupSize: Int, curSum: Int = 0, curOffs: Seq[BigDecimal] = Seq.empty[BigDecimal]): Seq[BigDecimal] = {
val grps = if (grpsIn.isInstanceOf[mutable.WrappedArray[_]] || grpsIn.isInstanceOf[Array[_]]) {
Seq(grpsIn: _*)
} else grpsIn
if (grps.isEmpty) curOffs
else {
val (id, cnt) = grps.head
val newSum = curSum + cnt.toInt
if (cntr.getAndIncrement % 500==0) { println(s"[${cntr.get}] id=$id newsum=$newSum") }
if (newSum >= groupSize) {
getBoundaries(grps.tail, groupSize, 0, curOffs :+ id)
} else {
getBoundaries(grps.tail, groupSize, newSum, curOffs)
}
}
}