StackOverflow和其他地方有很多声明nth_element
O(n),并且通常使用Introselect实现:http://en.cppreference.com/w/cpp/algorithm/nth_element
我想知道如何实现这一目标。我看了Wikipedia's explanation of Introselect,这让我更加困惑。算法如何在QSort和Median-of-Medians之间切换?
我在这里找到了Introsort论文:http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.14.5196&rep=rep1&type=pdf但是那说:
在本文中,我们将集中讨论排序问题,并在后面的章节中简要回到选择问题。
我试图通过STL本身阅读以了解nth_element
是如何实现的,但这种方法很快就会变得毛茸茸。
有人可以向我展示如何实施Introselect的伪代码吗?或者甚至更好,当然除了STL之外的实际C ++代码:)
答案 0 :(得分:13)
你问了两个问题,那个名义上的问题
nth_element是如何实现的?
您已经回答:
StackOverflow和其他地方有很多声明nth_element是O(n),并且通常使用Introselect实现。
我也可以通过查看我的stdlib实现来确认。 (稍后会详细介绍。)
那是你不理解答案的那个:
算法如何在QSort和Median-of-Medians之间切换?
让我们看看我从stdlib中提取的伪代码:
nth_element(first, nth, last)
{
if (first == last || nth == last)
return;
introselect(first, nth, last, log2(last - first) * 2);
}
introselect(first, nth, last, depth_limit)
{
while (last - first > 3)
{
if (depth_limit == 0)
{
// [NOTE by editor] This should be median-of-medians instead.
// [NOTE by editor] See Azmisov's comment below
heap_select(first, nth + 1, last);
// Place the nth largest element in its final position.
iter_swap(first, nth);
return;
}
--depth_limit;
cut = unguarded_partition_pivot(first, last);
if (cut <= nth)
first = cut;
else
last = cut;
}
insertion_sort(first, last);
}
在不详细了解引用函数heap_select
和unguarded_partition_pivot
的情况下,我们可以清楚地看到,nth_element
给出了内部选择2 * log2(size)
细分步骤(快速选择所需的两倍)在最好的情况下)直到heap_select
开始并解决问题。
答案 1 :(得分:10)
免责声明:我不知道在任何标准库中如何实现std::nth_element
。
如果您了解Quicksort的工作原理,您可以轻松修改它以执行此算法所需的操作。 Quicksort的基本思想是,在每个步骤中,您将数组分成两部分,使得小于枢轴的所有元素都在左子数组中,并且所有等于或大于数据透视的元素都在右子数组中。 (称为三元Quicksort的Quicksort的修改创建了第三个子阵列,所有元素都等于pivot。然后右子阵列只包含严格大于pivot的条目。)然后Quicksort通过递归排序左右子进行-arrays。
如果您只想将 n -th元素移动到位,而不是递归到两个子阵列,您可以在每个步骤中告诉您是否需要下降到左或右子阵列。 (你知道这是因为排序数组中的 n -th元素具有索引 n 所以它变成了比较索引的问题。)所以 - 除非你的Quicksort遭遇最坏情况退化 - 你大致将每一步中剩余阵列的大小减半。 (你永远不会再看另一个子数组。)因此,平均而言,你在每一步中处理以下长度的数组:
每个步骤在它所处理的数组的长度上是线性的。 (你循环一遍,决定每个元素应该去哪个子数组,具体取决于它与数据透视的比较。)
你可以看到在Θ(log( N ))步骤之后,我们最终会到达单个数组并完成。如果总结 N (1 + 1/2 + 1/4 + ...),你将获得2 N 。或者,在一般情况下,因为我们不能希望枢轴总是完全是中位数,所以大概是Θ( N )。
答案 2 :(得分:8)
STL(版本3.3,我认为)的代码是:
template <class _RandomAccessIter, class _Tp>
void __nth_element(_RandomAccessIter __first, _RandomAccessIter __nth,
_RandomAccessIter __last, _Tp*) {
while (__last - __first > 3) {
_RandomAccessIter __cut =
__unguarded_partition(__first, __last,
_Tp(__median(*__first,
*(__first + (__last - __first)/2),
*(__last - 1))));
if (__cut <= __nth)
__first = __cut;
else
__last = __cut;
}
__insertion_sort(__first, __last);
}
让我们简化一下:
template <class Iter, class T>
void nth_element(Iter first, Iter nth, Iter last) {
while (last - first > 3) {
Iter cut =
unguarded_partition(first, last,
T(median(*first,
*(first + (last - first)/2),
*(last - 1))));
if (cut <= nth)
first = cut;
else
last = cut;
}
insertion_sort(first, last);
}
我在这里做的是删除双下划线和_Uppercase东西,这只是为了保护代码免受用户可以合法定义为宏的内容。我还删除了最后一个参数,它只能用于模板类型推导,并为简洁起见重命名了迭代器类型。
正如您现在应该看到的那样,它会重复划分范围,直到剩余范围内剩余少于四个元素,然后进行简单排序。
现在,为什么O(n)?首先,最多三个元素的最终排序是O(1),因为最多有三个元素。现在,剩下的就是重复分区。 O(n)中的分区本身就是O(n)。在这里,每个步骤减少了下一步需要触摸的元素数量,因此你有O(n)+ O(n / 2)+ O(n / 4)+ O(n / 8)这是如果总结的话,小于O(2n)。由于O(2n)= O(n),因此平均具有线性复杂度。