N-tree traversal in Scala causes a stack overflow

Time: 2012-10-23 22:39:10

Tags: scala recursion tree traversal

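I'm trying to return a list of widgets from an N-tree data structure. In my unit test, if I have about 2000 widgets, each with one dependency, I get a stack overflow. I think what is happening is that the for loop prevents my tree traversal from being tail-recursive. What is a better way to write this in Scala? Here is my function: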

import scala.collection.mutable.ListBuffer

protected def getWidgetTree(key: String) : ListBuffer[Widget] = {
  def traverseTree(accumulator: ListBuffer[Widget], current: Widget) : ListBuffer[Widget] = {
    accumulator.append(current)

    if (!current.hasDependencies) {
      accumulator
    }  else {
      for (dependencyKey <- current.dependencies) {
        if (accumulator.indexWhere(_.name == dependencyKey) == -1) {  // indexWhere replaces the deprecated findIndexOf
          traverseTree(accumulator, getWidget(dependencyKey))
        }
      }

      accumulator
    }
  }

  traverseTree(ListBuffer[Widget](), getWidget(key))
}
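The suspicion about the for loop is on the right track. A for loop without yield is rewritten by the compiler into a foreach call, so the recursive traverseTree call above runs inside a lambda passed to foreach, not in tail position of traverseTree itself (annotating it with @tailrec makes the compiler report an error). A minimal sketch, using a hypothetical visitAll helper, showing that the two forms are the same call:

```scala
import scala.collection.mutable.ListBuffer

// Hypothetical helper: records the elements visited by each loop form.
def visitAll(keys: List[String]): (List[String], List[String]) = {
  val viaFor = ListBuffer[String]()
  val viaForeach = ListBuffer[String]()

  // A for loop without yield...
  for (k <- keys) viaFor.append(k)

  // ...is compiled to a foreach call with a lambda, so anything in the loop
  // body (including a recursive call) executes inside that lambda.
  keys.foreach(k => viaForeach.append(k))

  (viaFor.toList, viaForeach.toList)
}
```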

3 answers:

Answer 0 (score: 10)

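The reason it isn't tail-recursive is that you make multiple recursive calls inside the function. To be tail-recursive, the recursive call has to be the last expression evaluated in the function body. After all, the whole point is that the function works like a while loop (and can therefore be compiled into one), and a loop cannot invoke itself several times in a single iteration.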
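To make the "works like a while loop" point concrete, here is a minimal sketch (the sumTo names are illustrative, not from the question): a @tailrec function whose single recursive call is the last expression, next to the while loop it is equivalent to. Adding a second recursive call, or moving the call inside a closure, would make @tailrec fail to compile.

```scala
import scala.annotation.tailrec

// Tail-recursive: the recursive call is the last thing evaluated,
// so the compiler can turn it into a jump back to the top of loop.
def sumToRec(n: Int): Long = {
  @tailrec
  def loop(i: Int, acc: Long): Long =
    if (i > n) acc
    else loop(i + 1, acc + i)   // tail call
  loop(1, 0L)
}

// The equivalent while loop: one iteration per former recursive call.
def sumToLoop(n: Int): Long = {
  var i = 1
  var acc = 0L
  while (i <= n) {
    acc += i
    i += 1
  }
  acc
}
```

Because the recursion is compiled into a loop, sumToRec(1000000) runs in constant stack space.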

To do a tree traversal like this, you can use a queue to carry forward the nodes that still need to be visited.
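The same worklist idea can be written once for any graph. The sketch below is a generic variant, not the answer's code: traverse and its children parameter are hypothetical names, and it replaces the linear scan over the accumulator with an immutable Set of visited keys, so each duplicate check is effectively constant time.

```scala
import scala.annotation.tailrec

// Depth-first preorder traversal using an explicit worklist instead of the call stack.
// children returns the keys a node depends on; keys already visited are skipped,
// so cycles and shared dependencies are expanded only once.
def traverse[K](root: K)(children: K => List[K]): List[K] = {
  @tailrec
  def loop(worklist: List[K], visited: Set[K], out: List[K]): List[K] =
    worklist match {
      case Nil => out.reverse
      case k :: rest if visited.contains(k) =>
        loop(rest, visited, out)                          // already visited: drop it
      case k :: rest =>
        loop(children(k) ::: rest, visited + k, k :: out) // visit k, then its children
    }
  loop(List(root), Set.empty, Nil)
}
```

A Map[String, List[String]] can be passed directly as children, since a Map is also a function from keys to values.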

Suppose we have this tree:

//        1
//       / \  
//      2   5
//     / \
//    3   4

represented by this simple data structure:

case class Widget(name: String, dependencies: List[String]) {
  def hasDependencies = dependencies.nonEmpty
}

and this map from each name to its node:

val getWidget = List(
  Widget("1", List("2", "5")),
  Widget("2", List("3", "4")),
  Widget("3", List()),
  Widget("4", List()),
  Widget("5", List()))
  .map { w => w.name -> w }.toMap
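Because getWidget is a Map[String, Widget], and a Scala Map is also a function String => Widget, it can be called like the question's getWidget(key) method and passed straight to map. One caveat worth knowing: calling it with a key that is not in the map throws NoSuchElementException, so a dependency name with no matching widget would abort the traversal. A small sketch; the dangling "9" dependency and the safeDeps helper are hypothetical:

```scala
case class Widget(name: String, dependencies: List[String])

val getWidget: Map[String, Widget] = List(
  Widget("1", List("2", "5")),
  Widget("2", List("3", "4")),
  Widget("3", List()),
  Widget("4", List()),
  Widget("5", List("9"))      // hypothetical: "9" has no widget
).map { w => w.name -> w }.toMap

// A Map can be used wherever a function String => Widget is expected.
val rootChildren: List[Widget] = getWidget("1").dependencies.map(getWidget)

// Hypothetical helper: resolve dependencies, silently skipping unknown names.
def safeDeps(w: Widget): List[Widget] = w.dependencies.flatMap(d => getWidget.get(d))
```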

Now we can rewrite your method to be tail-recursive:

import scala.annotation.tailrec

def getWidgetTree(key: String): List[Widget] = {
  @tailrec
  def traverseTree(queue: List[String], accumulator: List[Widget]): List[Widget] = {
    queue match {
      case currentKey :: queueTail =>        // the queue is not empty
        val current = getWidget(currentKey)  // get the element at the front
        val newQueueItems =                  // skip dependencies already visited or already queued
          current.dependencies.filterNot(dependencyKey =>
            accumulator.exists(_.name == dependencyKey) || queue.contains(dependencyKey))
        traverseTree(newQueueItems ::: queueTail, current :: accumulator) // visit the dependencies next
      case Nil =>                            // the queue is empty
        accumulator.reverse                  // we're done
    }
  }

  traverseTree(key :: Nil, List[Widget]())
}

and test it:

for (k <- 1 to 5)
  println(getWidgetTree(k.toString).map(_.name))

This prints:

List(1, 2, 3, 4, 5)
List(2, 3, 4)
List(3)
List(4)
List(5)
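Note that the "queue" here is really used as a stack: new keys are prepended with newQueueItems ::: queueTail, which is why the output comes out in depth-first preorder. The sketch below makes that choice explicit; the order function and its breadthFirst flag are hypothetical, and dependencies are assumed to be a plain Map[String, List[String]]. Appending the new keys instead gives a true FIFO queue and breadth-first order. Since the only recursive call is in tail position, a dependency chain far deeper than the question's 2000 widgets does not grow the call stack.

```scala
import scala.annotation.tailrec

// Worklist traversal over a dependency map; breadthFirst picks where new keys go.
// Prepending (stack) gives depth-first preorder; appending (FIFO queue) gives breadth-first order.
def order(deps: Map[String, List[String]], root: String, breadthFirst: Boolean): List[String] = {
  @tailrec
  def loop(work: List[String], seen: List[String]): List[String] = work match {
    case Nil => seen.reverse
    case k :: rest =>
      // skip keys already visited or already waiting in the worklist
      val fresh = deps(k).filterNot(d => seen.contains(d) || work.contains(d))
      val next = if (breadthFirst) rest ::: fresh else fresh ::: rest
      loop(next, k :: seen)
  }
  loop(List(root), Nil)
}
```

For the example tree, order(deps, "1", breadthFirst = false) yields 1, 2, 3, 4, 5, while breadthFirst = true yields 1, 2, 5, 3, 4.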

Answer 1 (score: 4)

For the same example as in @dhg's answer, an equivalent tail-recursive function without mutable state (ListBuffer) would be:

import scala.annotation.tailrec

case class Widget(name: String, dependencies: List[String])

val getWidget = List(
  Widget("1", List("2", "5")),
  Widget("2", List("3", "4")),
  Widget("3", List()),
  Widget("4", List()),
  Widget("5", List())).map { w => w.name -> w }.toMap

def getWidgetTree(key: String): List[Widget] = {
  def addIfNotAlreadyContained(widgetList: List[Widget], widgetNameToAdd: String): List[Widget] = {
    if (widgetList.find(_.name == widgetNameToAdd).isDefined) widgetList
    else                                                      widgetList :+ getWidget(widgetNameToAdd)
  }

  @tailrec
  def traverseTree(currentWidgets: List[Widget], acc: List[Widget]): List[Widget] = currentWidgets match {
    case Nil                                => {
      // If there are no more widgets in this branch return what we've traversed so far
      acc
    }
    case Widget(name, _) :: rest if acc.exists(_.name == name) => {
      // If the first widget was already traversed, skip it; expanding it again would loop forever on cyclic dependencies
      traverseTree(rest, acc)
    }
    case Widget(name, Nil) :: rest          => {
      // If the first widget is a leaf traverse the rest and add the leaf to the list of traversed
      traverseTree(rest, addIfNotAlreadyContained(acc, name))
    }
    case Widget(name, dependencies) :: rest => {
      // If the first widget is a parent, traverse its children and the rest and add it to the list of traversed
      traverseTree(dependencies.map(getWidget) ++ rest, addIfNotAlreadyContained(acc, name))
    }
  }

  val root = getWidget(key)
  traverseTree(root.dependencies.map(getWidget) :+ root, List[Widget]())
}

With the same test case

for (k <- 1 to 5)
  println(getWidgetTree(k.toString).map(_.name))

you get:

List(2, 3, 4, 5, 1)
List(3, 4, 2)
List(3)
List(4)
List(5)

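Note that this is not a preorder traversal: the root widget is added last. It is not a strict postorder either; a postorder for key 1 would list every widget after all of its dependencies (3, 4, 2, 5, 1), whereas this function returns 2, 3, 4, 5, 1.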
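If a genuine postorder is wanted (every widget listed after all of its dependencies), one tail-recursive option is an explicit stack whose entries record whether a node's children have already been pushed. This is a sketch with hypothetical names (postorder, Enter, Exit), assuming dependencies are given as a Map[String, List[String]]; it is a different algorithm from the answer's.

```scala
import scala.annotation.tailrec

sealed trait Step
final case class Enter(key: String) extends Step // first visit: push the children
final case class Exit(key: String) extends Step  // children finished: emit the node

def postorder(deps: Map[String, List[String]], root: String): List[String] = {
  @tailrec
  def loop(stack: List[Step], seen: Set[String], out: List[String]): List[String] =
    stack match {
      case Nil => out.reverse
      case Exit(k) :: rest => loop(rest, seen, k :: out)
      case Enter(k) :: rest if seen.contains(k) => loop(rest, seen, out) // already handled
      case Enter(k) :: rest =>
        // push the children first, then an Exit marker so k is emitted after them
        val children = deps(k).map(c => Enter(c))
        loop(children ::: Exit(k) :: rest, seen + k, out)
    }
  loop(List(Enter(root)), Set.empty, Nil)
}
```

Marking a key as seen when it is first entered means each widget gets exactly one Exit, so shared dependencies and cycles are emitted once.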

Answer 2 (score: 1)

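Awesome! Thanks. I didn't know about the @tailrec annotation; that's a very cool little gem. I had to tweak the solution slightly because widgets with self-references caused an infinite loop. Also, newQueueItems was an Iterable while the call to traverseTree expects a List, so I had to convert that part to a List.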

import scala.annotation.tailrec

def getWidgetTree(key: String): List[Widget] = {
  @tailrec
  def traverseTree(queue: List[String], accumulator: List[Widget]): List[Widget] = {
    queue match {
      case currentKey :: queueTail =>        // the queue is not empty
        val current = getWidget(currentKey)  // get the element at the front
        val newQueueItems =                  // filter out the dependencies already known
          current.dependencies.filter(dependencyKey =>
            !accumulator.exists(_.name == dependencyKey) && !queue.contains(dependencyKey)).toList
        traverseTree(newQueueItems ::: queueTail, current :: accumulator) // visit the dependencies next
      case Nil =>                            // the queue is empty
        accumulator.reverse                  // we're done
    }
  }

  traverseTree(key :: Nil, List[Widget]())
}
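The two adjustments described in this comment can be seen in isolation in the following sketch. It assumes, as the comment suggests, that the real dependencies field is some Iterable[String] rather than a List; Widget2, nextKeys and selfRef are hypothetical names. Filtering an Iterable returns an Iterable, which cannot be prepended to a List with ::: until it is converted with .toList, and because the current key is still at the front of the queue, the queue.contains check drops a widget's reference to itself.

```scala
// Hypothetical widget type whose dependencies are a general Iterable, not a List.
case class Widget2(name: String, dependencies: Iterable[String])

// Keys that still need visiting: not already in the accumulator and not already queued.
// .toList is needed because filter on an Iterable returns an Iterable,
// and ::: only accepts a List on its left-hand side.
def nextKeys(current: Widget2, accumulator: List[Widget2], queue: List[String]): List[String] =
  current.dependencies.filter(dependencyKey =>
    !accumulator.exists(_.name == dependencyKey) && !queue.contains(dependencyKey)).toList

// A widget that depends on itself and on "b".
val selfRef = Widget2("a", Vector("a", "b"))
```

With queue = List("a"), nextKeys(selfRef, Nil, List("a")) keeps only "b", so the self-reference never re-enters the queue.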