如何使我的快速选择算法更快

时间:2018-09-21 08:08:37

标签: algorithm scala

我参加了一门课程,研究算法,并且我们分配了一个用quickselect进行经典k:th最小元素(0是最小元素,sequence.length-1是最大元素)的任务。

算法平均应比排序方法快2 *

Arrays.sort

我的算法有效,但速度不够快。平均比上面的数组排序方法慢5倍。到目前为止,这是我的实现:

  def find(sequence: Seq[Int], k: Int): Int = {
    require(0 <= k && k < sequence.length)
    val a: Array[Int] = sequence.toArray[Int]
    select(a,k)
  }

  def select(a: Array[Int], k: Int): Int = {
    val pivot = rand.nextInt(a.length)     
    val (low, middle, high) = partition(a,a(pivot))
    if (low.length == k) a(pivot)
    else if(low.length > k) select(low, k) 
    else if (low.length + middle.length >= k+1) middle(0)
    else if (low.length == 0) select(high, k - low.length-middle.length)
    else  findFast(high, k - low.length-middle.length)
  }

  def partition(array: Array[Int],pivot: Int): (Array[Int],Array[Int],Array[Int])={
    (array.filter(_<pivot),array.filter(_==pivot),array.filter(_>pivot))
  }

您能给我一些技巧来改善实现的运行时间吗?

2 个答案:

答案 0 :(得分:1)

在您的实现中,partition函数执行array.filter三次。

为避免这种情况,您可以将Scala partition方法用作Rosettacode shows-请注意,代码不会两次执行partition(不知道实际运行时间)

import scala.util.Random


object QuickSelect {
  def quickSelect[A <% Ordered[A]](seq: Seq[A], n: Int, rand: Random = new Random): A = {
    val pivot = rand.nextInt(seq.length);
    val (left, right) = seq.partition(_ < seq(pivot))
    if (left.length == n) {
      seq(pivot)
    } else if (left.length < n) {
      quickSelect(right, n - left.length, rand)
    } else {
      quickSelect(left, n, rand)
    }
  }

  def main(args: Array[String]): Unit = {
    val v = Array(9, 8, 7, 6, 5, 0, 1, 2, 3, 4)
    println((0 until v.length).map(quickSelect(v, _)).mkString(", "))
  }
}

或在Scala中实现经典的Hoare或Lomuto分区。

algorithm partition(A, lo, hi) is
    pivot := A[lo]
    i := lo - 1
    j := hi + 1
    loop forever
        do
            i := i + 1
        while A[i] < pivot
        do
            j := j - 1
        while A[j] > pivot
        if i >= j then
            return j
        swap A[i] with A[j]

请注意,这里的工作在相同的数组/序列中进行(就地方法)-是否适合Scala(可变性等)?如果不适用-只需遍历序列,将小项目写入low序列,将大项目写入high序列。伪代码:

def partition(A,low,equal,high, pivot):
   for item in A:
      if item < pivot:
          low[lowidx++] = item
      elif item > pivot:
          high[highidx++] = item
      else:
          equal[eqidx++] = item

(虽然实际上不需要equal部分-您可以从其他长度中获取它的长度)

答案 1 :(得分:1)

虽然quickSelect 上有很多文章,并且伪代码的标准模板已经足够好,但我总是很难在 Scala 中找到好的实现,而 Rosetta 博客中的代码很好,但可能会进入无限循环,尤其是当对左分区数组或右分区数组进行排序。

以下是使用 Scala 稍加修改的有效解决方案(涵盖重复元素和排序数组等情况)。

如果输入数组排序,基本上返回第k个元素

def isSorted[T](arr: List[T])(implicit ord: Ordering[T]): Boolean = arr match {
        case Nil => true
        case x :: Nil => true
        case x :: xs => ord.lteq(x, xs.head) && isSorted(xs)
}
def quickSelect(nums: List[Int], k: Int): Int = {
        // if the input array is sorted then no point partitioning further
        // and go into a potential infinite loop even with random pivot
        // logical to pick the kth element from sorted array
        if (isSorted(nums)) return nums(k)
        // else start the partition logic
        val pvt = (new scala.util.Random).nextInt(nums.length)
        val (lower, higher) = nums.partition( _ < nums(pvt))
        if (lower.length > k) quickSelect(lower, k)
        else if (lower.length < k) quickSelect(higher, k - lower.length)
        else nums(pvt)
}

希望这会有所帮助。