当某些条件为真时,Scala foldLeft

时间:2018-12-11 16:45:22

标签: scala fold foldleft

如何在Scala中模拟以下行为?即在满足某些特定条件的情况下继续折叠。

def foldLeftWhile[B](z: B, p: B => Boolean)(op: (B, A) => B): B

例如

scala> val seq = Seq(1, 2, 3, 4)
seq: Seq[Int] = List(1, 2, 3, 4)
scala> seq.foldLeftWhile(0, _ < 3) { (acc, e) => acc + e }
res0: Int = 1
scala> seq.foldLeftWhile(0, _ < 7) { (acc, e) => acc + e }
res1: Int = 6

更新:

基于@Dima的答案,我意识到我的意图有些副作用。因此,我使它与takeWhile同步,即谓词不匹配将不会前进。并添加更多示例以使其更加清晰。 (注意:不适用于Iterator s)

6 个答案:

答案 0 :(得分:3)

首先,请注意您的示例似乎是错误的。如果我正确理解您的描述,那么结果应该是1(满足谓词_ < 3的最后一个值),而不是6

最简单的方法是使用return语句,该语句在scala中非常皱眉,但我想,出于完整性考虑,我会提到它。

def foldLeftWhile[A, B](seq: Seq[A], z: B, p: B => Boolean)(op: (B, A) => B): B = foldLeft(z) { case (b, a) => 
   val result = op(b, a) 
   if(!p(result)) return b
   result
}

由于我们要避免使用return,因此可能会使用scanLeft

seq.toStream.scanLeft(z)(op).takeWhile(p).last

这有点浪费,因为它会累积所有(匹配的)结果。 您可以使用iterator而不是toStream来避免这种情况,但是Iterator由于某种原因没有.last,因此,您需要多花一点时间浏览它明确地:

 seq.iterator.scanLeft(z)(op).takeWhile(p).foldLeft(z) { case (_, b) => b }

答案 1 :(得分:3)

在scala中定义想要的内容非常简单。您可以定义一个隐式类,将您的函数添加到任何TraversableOnce(包括Seq)中。

implicit class FoldLeftWhile[A](trav: TraversableOnce[A]) {
  def foldLeftWhile[B](init: B)(where: B => Boolean)(op: (B, A) => B): B = {
    trav.foldLeft(init)((acc, next) => if (where(acc)) op(acc, next) else acc)
  }
}
Seq(1,2,3,4).foldLeftWhile(0)(_ < 3)((acc, e) => acc + e)

更新,因为问题已修改:

implicit class FoldLeftWhile[A](trav: TraversableOnce[A]) {
  def foldLeftWhile[B](init: B)(where: B => Boolean)(op: (B, A) => B): B = {
    trav.foldLeft((init, false))((a,b) => if (a._2) a else {
      val r = op(a._1, b)
      if (where(r)) (op(a._1, b), false) else (a._1, true)
    })._1
  }
}

请注意,我将您的(z:B,p:B =>布尔值)拆分为两个高阶函数。那只是个人scala风格的喜好。

答案 2 :(得分:2)

那呢:

def foldLeftWhile[A, B](z: B, xs: Seq[A], p: B => Boolean)(op: (B, A) => B): B = {
  def go(acc: B, l: Seq[A]): B = l match {
    case h +: t => 
        val nacc = op(acc, h)
        if(p(nacc)) go(op(nacc, h), t) else nacc
    case _ => acc
  }
  go(z, xs)
}

val a = Seq(1,2,3,4,5,6)
val r = foldLeftWhile(0, a, (x: Int) => x <= 3)(_ + _)
println(s"$r")

在谓词为true时对集合进行递归迭代,然后返回累加器。

您可以在scalafiddle上尝试

答案 3 :(得分:1)

第一拳:每个元素的谓词

首先,您可以使用内部尾部递归函数

implicit class TravExt[A](seq: TraversableOnce[A]) {
  def foldLeftWhile[B](z: B, f: A => Boolean)(op: (A, B) => B): B = {
    @tailrec
    def rec(trav: TraversableOnce[A], z: B): B = trav match {
      case head :: tail if f(head) => rec(tail, op(head, z))
      case _ => z
    }
    rec(seq, z)
  }
}

或简短版本

implicit class TravExt[A](seq: TraversableOnce[A]) {
  @tailrec
  final def foldLeftWhile[B](z: B, f: A => Boolean)(op: (A, B) => B): B = seq match {
    case head :: tail if f(head) => tail.foldLeftWhile(op(head, z), f)(op)
    case _ => z
  }
}

然后使用

val a = List(1, 2, 3, 4, 5, 6).foldLeftWhile(0, _ < 3)(_ + _) //a == 3

第二个示例:累加器值:

implicit class TravExt[A](seq: TraversableOnce[A]) {
  def foldLeftWhile[B](z: B, f: A => Boolean)(op: (A, B) => B): B = {
    @tailrec
    def rec(trav: TraversableOnce[A], z: B): B = trav match {
      case _ if !f(z) => z
      case head :: tail => rec(tail, op(head, z))
      case _ => z
    }
    rec(seq, z)
  }
}

或简短版本

implicit class TravExt[A](seq: TraversableOnce[A]) {
  @tailrec
  final def foldLeftWhile[B](z: B, f: A => Boolean)(op: (A, B) => B): B = seq match {
    case _ if !f(z) => z
    case head :: tail => tail.foldLeftWhile(op(head, z), f)(op)
    case _ => z
  }
}

答案 4 :(得分:0)

一段时间后,我收到了很多好听的答案。所以,我将它们合并到了这篇文章中

@Dima的非常简洁的解决方案

implicit class FoldLeftWhile[A](seq: Seq[A]) {

  def foldLeftWhile[B](z: B)(p: B => Boolean)(op: (B, A) => B): B = {
    seq.toStream.scanLeft(z)(op).takeWhile(p).lastOption.getOrElse(z)
  }
}

由@ElBaulP提供(我做了一些修改以匹配@Dima的评论)

implicit class FoldLeftWhile[A](seq: Seq[A]) {

  def foldLeftWhile[B](z: B)(p: B => Boolean)(op: (B, A) => B): B = {
    @tailrec
    def foldLeftInternal(acc: B, seq: Seq[A]): B = seq match {
      case x :: _ =>
        val newAcc = op(acc, x)
        if (p(newAcc))
          foldLeftInternal(newAcc, seq.tail)
        else
          acc
      case _ => acc
    }

    foldLeftInternal(z, seq)
  }
}

我的答案(涉及副作用)

implicit class FoldLeftWhile[A](seq: Seq[A]) {

  def foldLeftWhile[B](z: B)(p: B => Boolean)(op: (B, A) => B): B = {
    var accumulator = z
    seq
      .map { e =>
        accumulator = op(accumulator, e)
        accumulator -> e
      }
      .takeWhile { case (acc, _) =>
        p(acc)
      }
      .lastOption
      .map { case (acc, _) =>
        acc
      }
      .getOrElse(z)
  }
}

答案 5 :(得分:-1)

只需在累加器上使用分支条件:

seq.foldLeft(0, _ < 3) { (acc, e) => if (acc < 3) acc + e else acc}

但是,您将运行序列的每个条目。