遍历二叉树时构建不可变集合

时间:2017-03-06 00:51:16

标签: scala

我正在使用下面的类MessageTree来表示包含Message实例的二叉树,我想实现一个返回新MessageTree的方法filter包含满足给定谓词的所有消息。我有一个使用ListBuffer的解决方案,我正在寻找一个解决方案,在遍历树时构建一个不可变的集合,而不是ListBuffer。任何见解都表示赞赏。

object scratchpad extends App {

  import scala.collection.mutable.ListBuffer

  // A class to represent messages
  class Message(val text: String, val user: String, val comments: Int) {
    override def toString: String =
      "User: " + user + "\n" +
        "Text: " + text + " [" + comments + " comments]"
  }

  /* --------------------------------------------------------------------- */

  abstract class MessageTree {

    type MessagePredicate = Message => Boolean

    def traverse(p: MessagePredicate, storage: ListBuffer[Message]) : Unit

    def filter(p: Message => Boolean): MessageTree = filterAcc(p, new Empty)

    def filterAcc(p: Message => Boolean, acc: MessageTree): MessageTree

    def incl(tweet: Message): MessageTree

    def foreach(f: Message => Unit): Unit

  }

  /* --------------------------------------------------------------------- */

  class Empty extends MessageTree {

    override def traverse(p: MessagePredicate, storage: ListBuffer[Message]): Unit = {}

    override def filterAcc(p: (Message) => Boolean, acc: MessageTree): MessageTree = acc

    def incl(tweet: Message): MessageTree = new NonEmpty(tweet, new Empty, new Empty)

    def foreach(f: Message => Unit): Unit = ()

  }

  /* --------------------------------------------------------------------- */

  class NonEmpty(elem: Message, left: MessageTree, right: MessageTree) extends MessageTree {

    override def traverse(p: MessagePredicate, storage: ListBuffer[Message]): Unit = {
      left.traverse(p, storage)
      if (p(elem)) storage += elem
      right.traverse(p, storage)
    }

    override def filterAcc(p: (Message) => Boolean, acc: MessageTree): MessageTree = {
      val tweet_collector = ListBuffer.empty[Message]
      traverse(p, tweet_collector)
      def loop(listBuffer: ListBuffer[Message], accum: MessageTree) : MessageTree = {
        if (listBuffer.isEmpty) accum
        else loop(listBuffer.tail, accum incl listBuffer.head)
      }
      loop(tweet_collector, acc)
    }

    def incl(x: Message): MessageTree = {
      if (x.text < elem.text) new NonEmpty(elem, left.incl(x), right)
      else if (elem.text < x.text) new NonEmpty(elem, left, right.incl(x))
      else this
    }

    def foreach(f: Message => Unit): Unit = {
      left.foreach(f)
      f(elem)
      right.foreach(f)
    }

  }

  /* --------------------------------------------------------------------- */
  /* Test
  /* --------------------------------------------------------------------- */
  val keyword_list: List[String] = "one" :: "three" :: "five" :: "seven" :: "nine" :: Nil

  def search_predicate(msg: Message): Boolean = {
    keyword_list.exists(msg.text.contains(_))
  }

  val msg_00 = new Message("zero", "John_00", 50)
  val msg_01 = new Message("one", "John_01", 10)
  val msg_02 = new Message("two", "John_02", 20)
  val msg_03 = new Message("three", "John_03", 30)
  val msg_04 = new Message("four", "John_04", 40)
  val msg_05 = new Message("five", "John_05", 50)

  val tmp_message_tree = new NonEmpty(msg_00, new Empty, new Empty)
  val message_tree = tmp_message_tree incl msg_01 incl msg_02 incl msg_03 incl msg_04 incl msg_05

  println("All messages")
  message_tree foreach println

  println("Filtered messages")
  (message_tree filter search_predicate) foreach println


}

输出

  

所有消息

     

用户:John_05
  文字:五篇[50条评论]
  用户:John_04
  文字:四[40评论]
  用户:John_01
  文字:一篇[10条评论]
  用户:John_03
  文字:三篇[30条评论]
  用户:John_02
  文字:两篇[20条评论]
  用户:John_00
  文字:零[50条评论]

     

已过滤的讯息

     

用户:John_05
  文字:五篇[50条评论]
  用户:John_01
  文字:一篇[10条评论]
  用户:John_03
  文字:三篇[30条评论]

1 个答案:

答案 0 :(得分:1)

你几乎就在那里:你只需像这样改变遍历:

def traverse(p: MessagePredicate):List[Message] = {
  left.traverse(p) ++ 
  (if (p(elem)) List(elem) else List()) ++
  right.traverse(p)
}  

这个解决方案很简单,但它使用了太多的列表连接,这是昂贵的。另一种方法是使用累加器来减少连接:

def traverse(p: MessagePredicate, acc:List[Message]):List[Message] = {
  val r = right.traverse(acc)
  left.traverse(p, if(p (elem)) elem::r else r)
}

和空节点:

def traverse(p: MessagePredicate, acc:List[Message]):List[Message] = acc