在Scala中对无限流进行嵌套迭代

时间:2012-12-25 00:03:07

标签: scala stream for-comprehension

我有时发现自己想要在Scala for 理解中对无限流执行嵌套迭代,但是指定循环终止条件可能有点棘手。是否有更好的方法来做这种事情?

我想到的用例是我不一定知道我正在迭代的每个无限流中需要多少个元素(但显然我知道它不会是无穷大的数字) )。假设每个流的终止条件可能以某种复杂的方式依赖于for表达式中其他元素的值。

最初的想法是尝试将流终止条件写为 for 表达式中的 if 过滤器子句,但是当循环嵌套的无限流时会遇到麻烦,因为没有办法在第一个无限流上短路迭代,最终导致OutOfMemoryError。我理解为什么会出现这种情况,假设 表达式如何映射到 map flatMap withFilter 方法调用 - 我的问题是,做这种事情是否有更好的习惯用语(或许根本不涉及 理解)。

为了举例说明上述问题,请考虑以下(非常天真)代码来生成数字1和2的所有配对:

val pairs = for {
  i <- Stream.from(1) 
  if i < 3 
  j <- Stream.from(1) 
  if j < 3
} 
yield (i, j)

pairs.take(2).toList 
// result: List[(Int, Int)] = List((1,1), (1,2)) 

pairs.take(4).toList
// 'hoped for' result: List[(Int, Int)] = List((1,1), (1,2), (2,1), (2,2))
// actual result:
//  java.lang.OutOfMemoryError: Java heap space
//      at scala.collection.immutable.Stream$.from(Stream.scala:1105)

显然,在这个简单的示例中,可以通过将 if 过滤器移动到原始流上的 takeWhile 方法调用来轻松避免此问题,如下所示:

val pairs = for {
  i <- Stream.from(1).takeWhile(_ < 3) 
  j <- Stream.from(1).takeWhile(_ < 3) 
}    
yield (i, j)

但是出于问题的目的想象一个更复杂的用例,其中流终止条件不能轻易地移动到流表达式本身。

2 个答案:

答案 0 :(得分:2)

一种可能性是将Stream包装到您自己的处理filter的类中,在本例中为takeWhile

import scala.collection._
import scala.collection.generic._

class MyStream[+A]( val underlying: Stream[A] ) {
  def flatMap[B, That](f: (A) => GenTraversableOnce[B])(implicit bf: CanBuildFrom[Stream[A], B, That]): That = underlying.flatMap(f);

  def map[B, That](f: (A) ⇒ B)(implicit bf: CanBuildFrom[Stream[A], B, That]): That = underlying.map(f);

  def filter(p: A => Boolean): Stream[A] = underlying.takeWhile(p);
  //                                       ^^^^^^^^^^^^^^^^^^^^^^^^
}

object MyStream extends App {
  val pairs = for {
    i <- new MyStream(Stream.from(1))
    if i < 3
    j <- new MyStream(Stream.from(1))
    if j < 3
  } yield (i, j);

  print(pairs.toList);
}

这会打印List((1,1), (1,2), (2,1), (2,2))

答案 1 :(得分:0)

我已经采用了Petr的建议来提出我认为是一种更普遍可用的解决方案,因为它没有限制 if 过滤器在for comprehension中的定位(尽管它有更多的语法开销)。

我们的想法是再次将基础流封装在包装器对象中,该对象委托flatMapmapfilter方法而不进行修改,但首先应用takeWhile调用基础流,谓词为!isTruncated,其中isTruncated是属于包装器对象的字段。在任何时候在包装器对象上调用truncate将翻转isTruncated标志并有效地终止对流的进一步迭代。这很大程度上依赖于对基础流的takeWhile调用进行了懒惰评估的事实,因此在迭代的后期执行的代码可能会影响其行为。

缺点是您必须通过将 || s.truncate附加到过滤器表达式(其中s是对包装的引用来保留对您希望能够截断迭代中的流的引用)流)。您还需要确保在每次通过流的新迭代之前在包装器对象上调用reset(或使用新的包装器对象),除非您知道每次重复迭代的行为都相同。

import scala.collection._
import scala.collection.generic._

class TruncatableStream[A]( private val underlying: Stream[A]) {
  private var isTruncated = false;

  private var active = underlying.takeWhile(a => !isTruncated)

  def flatMap[B, That](f: (A) => GenTraversableOnce[B])(implicit bf: CanBuildFrom[Stream[A], B, That]): That = active.flatMap(f);

  def map[B, That](f: (A) => B)(implicit bf: CanBuildFrom[Stream[A], B, That]): That = active.map(f);

  def filter(p: A => Boolean): Stream[A] = active.filter(p);

  def truncate() = {
    isTruncated = true
    false
  }

  def reset() = {
    isTruncated = false
    active = underlying.takeWhile(a => !isTruncated)
  }
}

val s1 = new TruncatableStream(Stream.from(1))
val s2 = new TruncatableStream(Stream.from(1))

val pairs = for {
  i <- s1

  // reset the nested iteration at the start of each outer iteration loop 
  // (not strictly required here as the repeat iterations are all identical)
  // alternatively, could just write: s2 = new TruncatableStream(Stream.from(1))  
  _ = _s2.reset()      

  j <- s2
  if i < 3 || s1.truncate
  if j < 3 || s2.truncate
} 
yield (i, j)

pairs.take(2).toList  // res1: List[(Int, Int)] = List((1,1), (1,2))
pairs.take(4).toList  // res2: List[(Int, Int)] = List((1,1), (1,2), (2,1), (2,2))

毫无疑问,这可以改进,但似乎是一个合理的解决方案。