首先,我想说这是一项学校作业,我只是寻求一些指导。
我的任务是编写一个算法,使用quickselect查找seq中的第k个最小元素。这应该很容易,但是当我进行一些测试时,我碰壁了。出于某种原因,如果我使用输入(List(1, 1, 1, 1), 1)
,它将进入无限循环。
这是我的实施:
val rand = new scala.util.Random()
def find(seq: Seq[Int], k: Int): Int = {
require(0 <= k && k < seq.length)
val a: Array[Int] = seq.toArray[Int] // Can't modify the argument sequence
val pivot = rand.nextInt(a.length)
val (low, high) = a.partition(_ < a(pivot))
if (low.length == k) a(pivot)
else if (low.length < k) find(high, k - low.length)
else find(low, k)
}
出于某种原因(或因为我累了)我无法发现我的错误。如果有人在我出错的地方暗示我,我会很高兴的。
答案 0 :(得分:1)
基本上你依赖这一行 - val (low, high) = a.partition(_ < a(pivot))
将数组拆分成2个数组。第一个包含小于pivot-element的连续元素序列,第二个包含其余元素。
然后你说如果第一个数组的长度为k
,那意味着你已经看到k
元素比你的pivot元素小。这意味着pivot-element实际上是k+1
最小的,你实际上是返回k+1
个最小元素而不是k
th。这是你的第一个错误。
另外......当你的所有元素都相同时会出现更大的问题,因为你的第一个数组总是有0个元素。
不仅如此......您的代码会为您提供错误的答案输入,其中k
最小的元素包含重复元素,例如 - (1, 3, 4, 1, 2)
。
解决方案在于,在序列(1,1,1,1)中,4
最小元素是4
th 1
。这意味着您必须使用<=
而不是<
。
另外......由于partition
函数在boolean
条件为假之前不会拆分数组,因此无法使用分区来实现此数组拆分。你必须自己写分裂。