了解Scala中的tailrec带注释的递归方法的性能

时间:2018-07-26 03:14:53

标签: scala

请考虑以下方法-已被验证符合正确的tail recursion

  @tailrec
  def getBoundaries(grps: Seq[(BigDecimal, Int)], groupSize: Int, curSum: Int = 0, curOffs: Seq[BigDecimal] = Seq.empty[BigDecimal]): Seq[BigDecimal] = {
    if (grps.isEmpty) curOffs
    else {
      val (id, cnt) = grps.head
      val newSum = curSum + cnt.toInt
      if (newSum%50==0) { println(s"id=$id newsum=$newSum") }
      if (newSum >= groupSize) {
        getBoundaries(grps.tail, groupSize, 0, curOffs :+ id) // r1
      } else {
        getBoundaries(grps.tail, groupSize, newSum, curOffs) // r2
      }
    }
  }

运行速度非常-大约每秒75个循环。当我点击stacktrace(Intellij的一个不错的功能)时,几乎每次当前调用的行都是第二个尾递归调用r2。这个事实使我对所谓的“ scala将递归调用分解为while循环”感到怀疑。如果正在发生解包,那么为什么调用本身会花费这么多时间?

除了具有适当结构的尾部递归方法外,还有其他考虑来使递归例程的性能接近直接迭代吗?

2 个答案:

答案 0 :(得分:2)

效果取决于Seq的基础类型。

如果它是List,则问题是将(:+)附加到List上,因为长列表会变得很慢,因为它必须扫描整个列表以找到结尾。

一种解决方案是每次将 放在列表(+:之前,然后在reverse末尾添加。这可以大大提高性能,因为添加到列表的开头非常快。

其他Seq类型将具有不同的性能特征,但是您可以在递归调用之前将其转换为List,以了解其性能。


这是示例代码

def getBoundaries(grps: Seq[(BigDecimal, Int)], groupSize: Int): Seq[BigDecimal] = {
  @tailrec
  def loop(grps: List[(BigDecimal, Int)], curSum: Int, curOffs: List[BigDecimal]): List[BigDecimal] =
    if (grps.isEmpty) curOffs
    else {
      val (id, cnt) = grps.head
      val newSum = curSum + cnt.toInt

      if (newSum >= groupSize) {
        loop(grps.tail, 0, id +: curOffs) // r1
      } else {
        loop(grps.tail, newSum, curOffs) // r2
      }
    }

  loop(grps.toList, 0, Nil).reverse
}

使用发问者自己对问题的答案中提供的测试数据,此版本的性能比原始代码提高了10倍。

答案 1 :(得分:0)

问题不在递归中,而是在数组操作中。使用以下测试用例,它以每秒 200K递归

的速度运行
  type Fgroups = Seq[(BigDecimal, Int)]
  test("testGetBoundaries") {
    val N = 200000
    val grps: Fgroups = (N to 1 by -1).flatMap { x => Array.tabulate(x % 20){ x2 => (BigDecimal(x2 * 1e9), 1) }}
    val sgrps = grps.sortWith { case (a, b) =>
      a._1.longValue.compare(b._1.longValue) < 0
    }
    val bb = getBoundaries(sgrps, 100 )
    println(bb.take(math.min(50,bb.length)).mkString(","))
    assert(bb.length==1900)
  }

我的生产数据样本的条目数量相似(Array有233K行),但运行速度却慢了 3个数量级。我现在正在调查 tail 操作和其他罪魁祸首。

更新来自Alvin Alexander的以下参考信息指出,对于不可变集合,tail操作应该是v快速的-但长时间< em>可变的-包括数组的!

https://alvinalexander.com/scala/understanding-performance-scala-collections-classes-methods-cookbook

enter image description here

enter image description here

哇!我不知道在scala中使用mutable集合会对性能产生影响!

更新,通过添加代码将Array转换为(immutableSeq,我发现生产数据样本的性能提高了3个数量级:

val grps = if (grpsIn.isInstanceOf[mutable.WrappedArray[_]] || grpsIn.isInstanceOf[Array[_]]) {
  Seq(grpsIn: _*)
} else grpsIn

(现在 fast 〜200K / sec)的最终代码是:

  type Fgroups = Seq[(BigDecimal, Int)]  
  val cntr = new java.util.concurrent.atomic.AtomicInteger
  @tailrec
  def getBoundaries(grpsIn: Fgroups, groupSize: Int, curSum: Int = 0, curOffs: Seq[BigDecimal] = Seq.empty[BigDecimal]): Seq[BigDecimal] = {
    val grps = if (grpsIn.isInstanceOf[mutable.WrappedArray[_]] || grpsIn.isInstanceOf[Array[_]]) {
      Seq(grpsIn: _*)
    } else grpsIn

    if (grps.isEmpty) curOffs
    else {
      val (id, cnt) = grps.head
      val newSum = curSum + cnt.toInt
      if (cntr.getAndIncrement % 500==0) { println(s"[${cntr.get}] id=$id newsum=$newSum") }
      if (newSum >= groupSize) {
        getBoundaries(grps.tail, groupSize, 0, curOffs :+ id)
      } else {
        getBoundaries(grps.tail, groupSize, newSum, curOffs)
      }
    }
  }