获取Scala Iterable的前n个元素的最简单方法

时间:2011-04-15 09:21:27

标签: algorithm scala

是否有一个简单而有效的解决方案来确定Scala Iterable的前n个元素?我的意思是

iter.toList.sortBy(_.myAttr).take(2)

但是当只有前2个感兴趣时,无需对所有元素进行排序。理想情况下,我正在寻找像

这样的东西
iter.top(2, _.myAttr)

另请参阅:使用订购的顶部元素的解决方案:In Scala, how to use Ordering[T] with List.min or List.max and keep code readable

更新

谢谢大家的解决方案。最后,我采用用户未知的原始解决方案并将其用于Iterable pimp-my-library 模式:

implicit def iterExt[A](iter: Iterable[A]) = new {
  def top[B](n: Int, f: A => B)(implicit ord: Ordering[B]): List[A] = {
    def updateSofar (sofar: List [A], el: A): List [A] = {
      //println (el + " - " + sofar)

      if (ord.compare(f(el), f(sofar.head)) > 0)
        (el :: sofar.tail).sortBy (f)
      else sofar
    }

    val (sofar, rest) = iter.splitAt(n)
    (sofar.toList.sortBy (f) /: rest) (updateSofar (_, _)).reverse
  }
}

case class A(s: String, i: Int)
val li = List (4, 3, 6, 7, 1, 2, 9, 5).map(i => A(i.toString(), i))
println(li.top(3, _.i))

9 个答案:

答案 0 :(得分:19)

我的解决方案(绑定到Int,但应该很容易更改为Ordered(请几分钟):

def top (n: Int, li: List [Int]) : List[Int] = {

  def updateSofar (sofar: List [Int], el: Int) : List [Int] = {
    // println (el + " - " + sofar)
    if (el < sofar.head) 
      (el :: sofar.tail).sortWith (_ > _) 
    else sofar
  }

  /* better readable:
    val sofar = li.take (n).sortWith (_ > _)
    val rest = li.drop (n)
    (sofar /: rest) (updateSofar (_, _)) */    
  (li.take (n). sortWith (_ > _) /: li.drop (n)) (updateSofar (_, _)) 
}

用法:

val li = List (4, 3, 6, 7, 1, 2, 9, 5)    
top (2, li)
  • 对于上面的列表,将前2(4,3)作为起始TopTen(TopTwo)。
  • 对它们进行排序,使得第一个元素是较大的元素(如果有的话)。
  • 重复遍历列表的其余部分(li.drop(n)),并将当前元素与最小值列表的最大值进行比较;如有必要,请更换,并再次使用。
  • 改进:
    • 扔掉Int,并使用有序。
    • 扔掉(_&gt; _)并使用用户排序以允许BottomTen。 (更难:选中间10 :))
    • 扔掉List,然后使用Iterable

更新(抽象):

def extremeN [T](n: Int, li: List [T])
  (comp1: ((T, T) => Boolean), comp2: ((T, T) => Boolean)):
     List[T] = {

  def updateSofar (sofar: List [T], el: T) : List [T] =
    if (comp1 (el, sofar.head)) 
      (el :: sofar.tail).sortWith (comp2 (_, _)) 
    else sofar

  (li.take (n) .sortWith (comp2 (_, _)) /: li.drop (n)) (updateSofar (_, _)) 
}

/*  still bound to Int:  
def top (n: Int, li: List [Int]) : List[Int] = {
  extremeN (n, li) ((_ < _), (_ > _))
}
def bottom (n: Int, li: List [Int]) : List[Int] = {
  extremeN (n, li) ((_ > _), (_ < _))
}
*/

def top [T] (n: Int, li: List [T]) 
  (implicit ord: Ordering[T]): Iterable[T] = {
  extremeN (n, li) (ord.lt (_, _), ord.gt (_, _))
}
def bottom [T] (n: Int, li: List [T])
  (implicit ord: Ordering[T]): Iterable[T] = {
  extremeN (n, li) (ord.gt (_, _), ord.lt (_, _))
}

top (3, li)
bottom (3, li)
val sl = List ("Haus", "Garten", "Boot", "Sumpf", "X", "y", "xkcd", "x11")
bottom (2, sl)

用Iterable替换List似乎有点困难。

正如Daniel C. Sobral在评论中指出的那样,topN中的高n可以导致很多排序工作,因此可以使用手动插入排序而不是重复排序整个列表前n个元素:

def extremeN [T](n: Int, li: List [T])
  (comp1: ((T, T) => Boolean), comp2: ((T, T) => Boolean)):
     List[T] = {

  def sortedIns (el: T, list: List[T]): List[T] = 
    if (list.isEmpty) List (el) else 
    if (comp2 (el, list.head)) el :: list else 
      list.head :: sortedIns (el, list.tail)

  def updateSofar (sofar: List [T], el: T) : List [T] =
    if (comp1 (el, sofar.head)) 
      sortedIns (el, sofar.tail)
    else sofar

  (li.take (n) .sortWith (comp2 (_, _)) /: li.drop (n)) (updateSofar (_, _)) 
}

上/下方法和用法如上。对于小组的顶部/底部元素,很少调用排序,在开始时几次,然后随着时间的推移越来越少。例如,70次顶部(10)的10 000次,90次顶部(10次)10万次。

答案 1 :(得分:7)

又一个版本:

val big = (1 to 100000)

def maxes[A](n:Int)(l:Traversable[A])(implicit o:Ordering[A]) =
    l.foldLeft(collection.immutable.SortedSet.empty[A]) { (xs,y) =>
      if (xs.size < n) xs + y
      else {
        import o._
        val first = xs.firstKey
        if (first < y) xs - first + y
        else xs
      }
    }

println(maxes(4)(big))
println(maxes(2)(List("a","ab","c","z")))

使用Set强制列表具有唯一值:

def maxes2[A](n:Int)(l:Traversable[A])(implicit o:Ordering[A]) =
    l.foldLeft(List.empty[A]) { (xs,y) =>
      import o._
      if (xs.size < n) (y::xs).sort(lt _)
      else {
        val first = xs.head
        if (first < y) (y::(xs - first)).sort(lt _)
        else xs
      }
    }

答案 2 :(得分:6)

这是另一种简单且性能相当不错的解决方案。

def pickTopN[T](k: Int, iterable: Iterable[T])(implicit ord: Ordering[T]): Seq[T] {
  val q = collection.mutable.PriorityQueue[T](iterable.toSeq:_*)
  val end = Math.min(k, q.size)
  (1 to end).map(_ => q.dequeue())
}

大O是O(n + k log n),其中k <= n。因此,对于小k和最差n log n,性能是线性的。

该解决方案还可以针对内存优化为O(k),但针对性能优化O(n log k)。我们的想法是使用MinHeap始终只跟踪前k项。这是解决方案。

def pickTopN[A, B](n: Int, iterable: Iterable[A], f: A => B)(implicit ord: Ordering[B]): Seq[A] = {
  val seq = iterable.toSeq
  val q = collection.mutable.PriorityQueue[A](seq.take(n):_*)(ord.on(f).reverse) // initialize with first n

  // invariant: keep the top k scanned so far
  seq.drop(n).foreach(v => {
    q += v
    q.dequeue()
  })

  q.dequeueAll.reverse
}

答案 3 :(得分:4)

您无需对整个集合进行排序,以确定前N个元素。但是,我不相信这个功能是由原始库提供的,所以你必须自己动手,可能使用 pimp-my-library 模式。

例如,您可以按如下方式获取集合的第n个元素:

  class Pimp[A, Repr <% TraversableLike[A, Repr]](self : Repr) {

    def nth(n : Int)(implicit ord : Ordering[A]) : A = {
      val trav : TraversableLike[A, Repr] = self
      var ltp : List[A] = Nil
      var etp : List[A] = Nil
      var mtp : List[A] = Nil
      trav.headOption match {
        case None      => error("Cannot get " + n + " element of empty collection")
        case Some(piv) =>
          trav.foreach { a =>
            val cf = ord.compare(piv, a)
            if (cf == 0) etp ::= a
            else if (cf > 0) ltp ::= a
            else mtp ::= a
          }
          if (n < ltp.length)
            new Pimp[A, List[A]](ltp.reverse).nth(n)(ord)
          else if (n < (ltp.length + etp.length))
            piv
          else
            new Pimp[A, List[A]](mtp.reverse).nth(n - ltp.length - etp.length)(ord)
      }
    }
  }

(这不是很实用;对不起)

获得最高n元素是非常简单的:

def topN(n : Int)(implicit ord : Ordering[A], bf : CanBuildFrom[Repr, A, Repr]) ={
  val b = bf()
  val elem = new Pimp[A, Repr](self).nth(n)(ord)
  import util.control.Breaks._
  breakable {
    var soFar = 0
    self.foreach { tt =>
      if (ord.compare(tt, elem) < 0) {
         b += tt
         soFar += 1
      }
    }
    assert (soFar <= n)
    if (soFar < n) {
      self.foreach { tt =>
        if (ord.compare(tt, elem) == 0) {
          b += tt
          soFar += 1
        }
        if (soFar == n) break
      }
    }

  }
  b.result()
}

不幸的是,我很难通过这个暗示来发现这个皮条客:

implicit def t2n[A, Repr <% TraversableLike[A, Repr]](t : Repr) : Pimp[A, Repr] 
  = new Pimp[A, Repr](t)

我明白了:

scala> List(4, 3, 6, 7, 1, 2, 8, 5).topN(4)
<console>:9: error: could not find implicit value for evidence parameter of type (List[Int]) => scala.collection.TraversableLike[A,List[Int]]
   List(4, 3, 6, 7, 1, 2, 8, 5).topN(4)
       ^

但是,代码实际上可以正常工作:

scala> new Pimp(List(4, 3, 6, 7, 1, 2, 8, 5)).topN(4)
res3: List[Int] = List(3, 1, 2, 4)

scala> new Pimp("ioanusdhpisjdmpsdsvfgewqw").topN(6)
res2: java.lang.String = adddfe

答案 4 :(得分:2)

如果目标是不对整个列表进行排序,那么你可以做这样的事情(当然它可以优化一点,以便我们不会在数字明显不应该出现时更改列表):< / p>

List(1,6,3,7,3,2).foldLeft(List[Int]()){(l, n) => (n :: l).sorted.take(2)}

答案 5 :(得分:1)

我最近在Apache Jackrabbit的Rank类中实现了这样的排名算法(尽管在Java中)。请参阅take方法获取它的要点。基本思路是快速排序,但一旦找到顶级n元素,就会过早终止。

答案 6 :(得分:1)

这是渐近 O(n)解决方案。

def top[T](data: List[T], n: Int)(implicit ord: Ordering[T]): List[T] = {
    require( n < data.size)

    def partition_inner(shuffledData: List[T], pivot: T): List[T] = 
      shuffledData.partition( e => ord.compare(e, pivot) > 0 ) match {
          case (left, right) if left.size == n => left
          case (left, x :: rest) if left.size < n => 
            partition_inner(util.Random.shuffle(data), x)
          case (left @ y :: rest, right) if left.size > n => 
            partition_inner(util.Random.shuffle(data), y)
      }

     val shuffled = util.Random.shuffle(data)
     partition_inner(shuffled, shuffled.head)
}

scala> top(List.range(1,10000000), 5)

由于递归,此解决方案将比上面的一些非线性解决方案花费更长时间并且可能导致java.lang.OutOfMemoryError: GC overhead limit exceeded。 但稍微更具可读性的恕我直言和功能风格。仅供求职面试;)。

更重要的是,此解决方案可以轻松并行化。

def top[T](data: List[T], n: Int)(implicit ord: Ordering[T]): List[T] = {
    require( n < data.size)

    @tailrec
    def partition_inner(shuffledData: List[T], pivot: T): List[T] = 
      shuffledData.par.partition( e => ord.compare(e, pivot) > 0 ) match {
          case (left, right) if left.size == n => left.toList
          case (left, right) if left.size < n => 
            partition_inner(util.Random.shuffle(data), right.head)
          case (left, right) if left.size > n => 
            partition_inner(util.Random.shuffle(data), left.head)
      }

     val shuffled = util.Random.shuffle(data)
     partition_inner(shuffled, shuffled.head)
}

答案 7 :(得分:0)

对于n和大型列表的较小值,可以通过选择最大元素n来实现获取最高n元素:

def top[T](n:Int, iter:Iterable[T])(implicit ord: Ordering[T]): Iterable[T] = {
  def partitionMax(acc: Iterable[T], it: Iterable[T]): Iterable[T]  = {
    val max = it.max(ord)
    val (nextElems, rest) = it.partition(ord.gteq(_, max))
    val maxElems = acc ++ nextElems
    if (maxElems.size >= n || rest.isEmpty) maxElems.take(n)
    else partitionMax(maxElems, rest)
  }
  if (iter.isEmpty) iter.take(0)
  else partitionMax(iter.take(0), iter)
}

这不会对整个列表进行排序,而是采用Ordering。我相信我在partitionMax中调用的每个方法都是O(列表大小),我只希望最多称它为n次,因此小n的整体效率将与迭代器的大小。

scala> top(5, List.range(1,1000000))
res13: Iterable[Int] = List(999999, 999998, 999997, 999996, 999995)

scala> top(5, List.range(1,1000000))(Ordering[Int].on(- _))
res14: Iterable[Int] = List(1, 2, 3, 4, 5)

您还可以在n接近可迭代的大小时添加分支,然后切换到iter.toList.sortBy(_.myAttr).take(n)

它不会返回所提供的集合类型,但如果需要,您可以查看How do I apply the enrich-my-library pattern to Scala collections?

答案 8 :(得分:0)

使用时间复杂度为O(nlogk)的{​​{1}}的优化解决方案。在更新中给出的方法中,您每次都会对sofar列表进行排序,而不是需要使用PriorityQueue进行优化。

import scala.language.implicitConversions
import scala.language.reflectiveCalls
import collection.mutable.PriorityQueue
implicit def iterExt[A](iter: Iterable[A]) = new {
    def top[B](n: Int, f: A => B)(implicit ord: Ordering[B]) : List[A] = {
        def updateSofar (sofar: PriorityQueue[A], el: A): PriorityQueue[A] = {
            if (ord.compare(f(el), f(sofar.head)) < 0){
                sofar.dequeue
                sofar.enqueue(el)
            }
            sofar
        }

        val (sofar, rest) = iter.splitAt(n)
        (PriorityQueue(sofar.toSeq:_*)( Ordering.by( (x :A) => f(x) ) ) /: rest) (updateSofar (_, _)).dequeueAll.toList.reverse
    }
}

case class A(s: String, i: Int)
val li = List (4, 3, 6, 7, 1, 2, 9, 5).map(i => A(i.toString(), i))
println(li.top(3, -_.i))