我有一个函数,在给定一个简单的node.id,node.parentId关联的情况下,计算某些treeNodes集合的左右节点值。它非常简单并且运作良好......但是,我想知道是否有更惯用的方法。具体来说,有一种方法可以跟踪左/右值而不使用一些外部跟踪的值,但仍保持美味的递归。
/*
* A tree node
*/
case class TreeNode(val id:String, val parentId: String){
var left: Int = 0
var right: Int = 0
}
/*
* a method to compute the left/right node values
*/
def walktree(node: TreeNode) = {
/*
* increment state for the inner function
*/
var c = 0
/*
* A method to set the increment state
*/
def increment = { c+=1; c } // poo
/*
* the tasty inner method
* treeNodes is a List[TreeNode]
*/
def walk(node: TreeNode): Unit = {
node.left = increment
/*
* recurse on all direct descendants
*/
treeNodes filter( _.parentId == node.id) foreach (walk(_))
node.right = increment
}
walk(node)
}
walktree(someRootNode)
编辑 - 节点列表取自数据库。将节点拉入适当的树将花费太多时间。我正在将一个平面列表拉入内存,我所拥有的是一个通过节点ID的关联,与父母和孩子有关。
添加左/右节点值允许我使用单个SQL查询获取所有子项(和子项的子项)的快照商店。
如果父子关联发生变化(他们经常这样做),计算需要非常快速地运行以保持数据的完整性。
除了使用令人敬畏的Scala集合之外,我还通过在树节点上使用并行处理进行一些前/后过滤来提高速度。我想找到一种跟踪左/右节点值的更惯用的方法。看完@dhg的答案后,它变得更好了。使用groupBy而不是过滤器会使算法(主要是?)线性而不是四角形!
val treeNodeMap = treeNodes.groupBy(_.parentId).withDefaultValue(Nil)
def walktree(node: TreeNode) = {
def walk(node: TreeNode, counter: Int): Int = {
node.left = counter
node.right =
treeNodeMap(node.id)
.foldLeft(counter+1) {
(result, curnode) => walk(curnode, result) + 1
}
node.right
}
walk(node,1)
}
答案 0 :(得分:6)
您的代码似乎正在计算有序遍历编号。
我认为您希望使代码更好的是fold
向下传递当前值并向上传递更新后的值。请注意,在treeNodes.groupBy(_.parentId)
之前执行walktree
可能还是值得的,以防止您每次拨打treeNodes.filter(...)
时都拨打walk
。
val treeNodes = List(TreeNode("1","0"),TreeNode("2","1"),TreeNode("3","1"))
val treeNodeMap = treeNodes.groupBy(_.parentId).withDefaultValue(Nil)
def walktree2(node: TreeNode) = {
def walk(node: TreeNode, c: Int): Int = {
node.left = c
val newC =
treeNodeMap(node.id) // get the children without filtering
.foldLeft(c+1)((c, child) => walk(child, c) + 1)
node.right = newC
newC
}
walk(node, 1)
}
它产生相同的结果:
scala> walktree2(TreeNode("0","-1"))
scala> treeNodes.map(n => "(%s,%s)".format(n.left,n.right))
res32: List[String] = List((2,7), (3,4), (5,6))
那就是说,我会完全重写你的代码如下:
case class TreeNode( // class is now immutable; `walktree` returns a new tree
id: String,
value: Int, // value to be set during `walktree`
left: Option[TreeNode], // recursively-defined structure
right: Option[TreeNode]) // makes traversal much simpler
def walktree(node: TreeNode) = {
def walk(nodeOption: Option[TreeNode], c: Int): (Option[TreeNode], Int) = {
nodeOption match {
case None => (None, c) // if this child doesn't exist, do nothing
case Some(node) => // if this child exists, recursively walk
val (newLeft, cLeft) = walk(node.left, c) // walk the left side
val newC = cLeft + 1 // update the value
val (newRight, cRight) = walk(node.right, newC) // walk the right side
(Some(TreeNode(node.id, newC, newLeft, newRight)), cRight)
}
}
walk(Some(node), 0)._1
}
然后你可以像这样使用它:
walktree(
TreeNode("1", -1,
Some(TreeNode("2", -1,
Some(TreeNode("3", -1, None, None)),
Some(TreeNode("4", -1, None, None)))),
Some(TreeNode("5", -1, None, None))))
生产:
Some(TreeNode(1,4,
Some(TreeNode(2,2,
Some(TreeNode(3,1,None,None)),
Some(TreeNode(4,3,None,None)))),
Some(TreeNode(5,5,None,None))))
答案 1 :(得分:1)
如果我正确地得到你的算法:
def walktree(node: TreeNode, c: Int): Int = {
node.left = c
val c2 = treeNodes.filter(_.parentId == node.id).foldLeft(c + 1) {
(cur, n) => walktree(n, cur)
}
node.right = c2 + 1
c2 + 2
}
walktree(new TreeNode("", ""), 0)
可能会发生逐个错误。
几个随机的想法(更适合http://codereview.stackexchange.com):
尝试发布编译...我们必须猜测这是一系列TreeNode
:
val
隐含于case
类:
case class TreeNode(val id: String, val parentId: String) {
对=
函数避免明确Unit
和Unit
:
def walktree(node: TreeNode) = {
def walk(node: TreeNode): Unit = {
有副作用的方法应该有()
:
def increment = {c += 1; c}
这非常慢,考虑在实际节点中存储子项列表:
treeNodes filter (_.parentId == node.id) foreach (walk(_))
更简洁的语法是treeNodes foreach walk
:
treeNodes foreach (walk(_))