我有一个关于在函数式中编写递归算法的问题。我将在这里使用Scala作为示例,但该问题适用于任何函数式语言。
我正在对 n -ary树进行深度优先枚举,其中每个节点都有一个标签和可变数量的子节点。这是一个简单的实现,它打印叶节点的标签。
case class Node[T](label:T, ns:Node[T]*)
def dfs[T](r:Node[T]):Seq[T] = {
if (r.ns.isEmpty) Seq(r.label) else for (n<-r.ns;c<-dfs(n)) yield c
}
val r = Node('a, Node('b, Node('d), Node('e, Node('f))), Node('c))
dfs(r) // returns Seq[Symbol] = ArrayBuffer('d, 'f, 'c)
现在说有时我希望能够通过抛出异常来放弃解析超大树。这可能是一种功能语言吗?具体是这可能不使用可变状态吗?这似乎取决于你所说的“超大”。这是算法的纯函数版本,当它尝试处理深度为3或更大的树时抛出异常。
def dfs[T](r:Node[T], d:Int = 0):Seq[T] = {
require(d < 3)
if (r.ns.isEmpty) Seq(r.label) else for (n<-r.ns;c<-dfs(n, d+1)) yield c
}
但是,如果一棵树太大而因为太宽而不是太深,会怎么样呢?具体来说,如果我想抛出异常 n 次,dfs()
函数是递归调用的,无论递归有多深?我能看到如何做到这一点的唯一方法是拥有一个可变计数器,该计数器随每次调用递增。没有可变变量,我无法看到如何做到这一点。
我是函数式编程的新手,并且一直在假设你可以用可变状态做任何事情都可以完成,但是我没有在这里看到答案。我唯一能想到的就是编写一个dfs()
版本,它以深度优先的顺序返回树中所有节点的视图。
dfs[T](r:Node[T]):TraversableView[T, Traversable[_]] = ...
然后我可以通过说dfs(r).take(n)
强加我的限制,但我不知道如何编写这个函数。在Python中,我只是在访问节点时通过yield
节点创建一个生成器,但我不知道如何在Scala中实现相同的效果。 (Scala等同于Python风格的yield
语句似乎是作为参数传入的访问者函数,但我无法弄清楚如何编写其中一个将生成序列视图的函数。)
编辑接近答案。
这是一个以深度优先顺序返回Stream
个节点的函数。
def dfs[T](r: Node[T]): Stream[Node[T]] = {
(r #:: Stream.empty /: r.ns)(_ ++ dfs(_))
}
差不多就是这样。唯一的问题是Stream
会记住所有结果,这会浪费内存。我想要一个可穿越的视图。以下是这个想法,但不编译。
def dfs[T](r: Node[T]): TraversableView[Node[T], Traversable[Node[T]]] = {
(Traversable(r).view /: r.ns)(_ ++ dfs(_))
}
它为TraversableView[Node[T], Traversable[Node[T]]]
运算符提供了“找到TraversableView[Node[T], Traversable[_]]
,必需++
错误。如果我将返回类型更改为TraversableView[Node[T], Traversable[_]]
,我会遇到同样的问题“找到”和“必需”条款已经切换。所以有一些魔法类型的方差咒语我还没有点亮,但这很接近。
答案 0 :(得分:7)
可以这样做:你只需编写一些代码就可以按照你想要的方式实际迭代孩子(而不是依赖for
)。
更明确地说,您必须编写代码来遍历子项列表并检查“深度”是否超过了您的阈值。这里有一些Haskell
代码(我真的很抱歉,我不熟悉Scala,但这可能很容易被音译):
在这段代码中,我基本上用for
循环替换了显式递归版本。如果访问节点的数量已经太深(即limit
不是正数),这允许我停止递归。当我递交检查下一个孩子时,我减去前一个孩子访问的dfs
节点的数量,并将其设置为下一个孩子的限制。
功能语言很有趣,但它们是命令式编程的巨大飞跃。它确实让你关注 state 的概念,因为当你运行时,所有这些都在参数中非常明确。
编辑:稍微解释一下。
我最终从“仅打印叶子节点”(这是OP中的原始算法)转换为“打印所有节点”。这使我能够通过结果列表的长度访问子查询访问的节点数。如果你想坚持叶子节点,你必须携带你已经访问过多少个节点:
再次编辑要清除这个答案,我将所有Haskell代码放在ideone上,并且我已将我的Haskell代码音译为Scala,因此这可以留在这里作为明确答案问题:
case class Node[T](label:T, children:Seq[Node[T]])
case class TraversalResult[T](num_visited:Int, labels:Seq[T])
def dfs[T](node:Node[T], limit:Int):TraversalResult[T] =
limit match {
case 0 => TraversalResult(0, Nil)
case limit =>
node.children match {
case Nil => TraversalResult(1, List(node.label))
case children => {
val result = traverse(node.children, limit - 1)
TraversalResult(result.num_visited + 1, result.labels)
}
}
}
def traverse[T](children:Seq[Node[T]], limit:Int):TraversalResult[T] =
limit match {
case 0 => TraversalResult(0, Nil)
case limit =>
children match {
case Nil => TraversalResult(0, Nil)
case first :: rest => {
val trav_first = dfs(first, limit)
val trav_rest =
traverse(rest, limit - trav_first.num_visited)
TraversalResult(
trav_first.num_visited + trav_rest.num_visited,
trav_first.labels ++ trav_rest.labels
)
}
}
}
val n = Node(0, List(
Node(1, List(Node(2, Nil), Node(3, Nil))),
Node(4, List(Node(5, List(Node(6, Nil))))),
Node(7, Nil)
))
for (i <- 1 to 8)
println(dfs(n, i))
输出:
TraversalResult(1,List())
TraversalResult(2,List())
TraversalResult(3,List(2))
TraversalResult(4,List(2, 3))
TraversalResult(5,List(2, 3))
TraversalResult(6,List(2, 3))
TraversalResult(7,List(2, 3, 6))
TraversalResult(8,List(2, 3, 6, 7))
P.S。这是我第一次尝试Scala,所以上面可能包含一些可怕的非惯用代码。对不起。
答案 1 :(得分:4)
您可以通过传递索引或取尾来将广度转换为深度:
def suml(xs: List[Int], total: Int = 0) = xs match {
case Nil => total
case x :: rest => suml(rest, total+x)
}
def suma(xs: Array[Int], from: Int = 0, total: Int = 0) = {
if (from >= xs.length) total
else suma(xs, from+1, total + xs(from))
}
在后一种情况下,如果你愿意,你已经有了限制你的广度的东西;在前者中,只需添加width
或其他一些。
答案 2 :(得分:2)
以下实现了对树中节点的延迟深度优先搜索。
import collection.TraversableView
case class Node[T](label: T, ns: Node[T]*)
def dfs[T](r: Node[T]): TraversableView[Node[T], Traversable[Node[T]]] =
(Traversable[Node[T]](r).view /: r.ns) {
(a, b) => (a ++ dfs(b)).asInstanceOf[TraversableView[Node[T], Traversable[Node[T]]]]
}
以深度优先顺序打印所有节点的标签。
val r = Node('a, Node('b, Node('d), Node('e, Node('f))), Node('c))
dfs(r).map(_.label).force
// returns Traversable[Symbol] = List('a, 'b, 'd, 'e, 'f, 'c)
这是做同样的事情,在访问了3个节点后退出。
dfs(r).take(3).map(_.label).force
// returns Traversable[Symbol] = List('a, 'b, 'd)
如果您只想要叶节点,可以使用filter
,等等。
请注意dfs
函数的fold子句需要显式asInstanceOf
强制转换。有关Scala输入问题的讨论,请参阅"Type variance error in Scala when doing a foldLeft over Traversable views"。