Scala:使用Ordering特性对比较进行了错误评估

时间:2019-05-23 10:38:28

标签: scala quicksort implicit

我有以下SampleSort的实现:

import scala.reflect.ClassTag
import ca.vgorcinschi.ArrayOps

import Ordered._

//noinspection SpellCheckingInspection
class SampleSort[T: ClassTag : Ordering](val sampleSize: Int = 30) extends QuickSort[T] {

  import SearchTree._

  override def sort(a: Array[T]): Array[T] = {
    require(a != null, "Passed-in array should not be null")
    sortHelper(a)
  }

  private def sortHelper(a: Array[T]): Array[T] = {
    //if the array is shorter then the sampling - sort it with Quicksort
    if (a.length <= sampleSize) return super.sort(a)

    /*
      just the indices for the sample array.
      also required later for figuring out the nonPartitionedRemainder of the array
     */
    val sampleArrayIndices: Array[Int] = a.subArrayOfSize(sampleSize)
    val sampleArray: Array[T] = sampleArrayIndices map (a(_))

    val sortedSampleArray: Array[T] = sort(sampleArray, 0, sampleArray.length - 1)
    val searchTree: SearchTree = buildTree(sortedSampleArray, sampleSize / 2)
    val nonPartitionedRemainder = a.slice(0, sampleArrayIndices.head) ++ a.slice(sampleArrayIndices.last + 1, a.length)
    val finalTree = (searchTree /: nonPartitionedRemainder) (_ nest _)
    finalTree.arrays() flatMap sort
  }

  private class SearchTree(lt: Array[T], median: Array[T], gt: Array[T]) {
    //hear median is guaranteed to be non null and non empty based off the partitioning in sortHelper
    private val pivot: T = median.head

    def nest(value: T): SearchTree = {
      if (value < pivot) SearchTree(lt :+ value, median, gt)
      if (value > pivot) SearchTree(lt, median, gt :+ value)
      else SearchTree(lt, median :+ value, gt)
    }

    def arrays(): Array[Array[T]] = Array(lt, median, gt)
  }

  private object SearchTree {
    def buildTree(sample: Array[T], pivot: Int): SearchTree = {
      //do not look beyond pivot since sample is guaranteed to be partitioned
      val lt = sample.takeWhile(_ < sample(pivot))
      //only look from pivot and up
      val medianAndGt: (Array[T], Array[T]) = sample.slice(lt.length, sample.length) partition (_ == sample(pivot))
      SearchTree(lt, medianAndGt._1, medianAndGt._2)
    }

    def apply(lt: Array[T], median: Array[T], gt: Array[T]): SearchTree = new SearchTree(lt, median, gt)
  }

}

简而言之,这段代码的作用:

  1. 排序传入数组的样本
  2. 将值lt,eq或gt放入相应的存储段
  3. 在这些存储分区之一中分配数组的未排序部分
  4. 递归重复

这目前在SearchTree.nest方法中失败(上面的第3点),因为所有值都进入了中位数(eq)存储桶中:

enter image description here

然而,使用相同的SearchTree.buildTree操作,在import Ordered._对象函数中也可以进行类似的比较!

我不确定我在这里想念什么。在此问题上,我将不胜感激。

1 个答案:

答案 0 :(得分:1)

您在else之前缺少if (value > pivot)。您当前在nest中的代码如下:

  1. 如果是value < pivot,请构建一个新的SearchTree并扔掉;

  2. 如果value > pivot ...

因此,当value < pivot成立时,您将获得第二个else的{​​{1}}分支。