我没有在任何地方找到这个特定主题......
我在23个整数的std :: vector中的不同数据上调用nth_element()算法,每秒大约400,000次,更精确的“无符号短”值。
我想提高计算速度,这个特定的调用需要很大一部分CPU时间。 现在我注意到,与std :: sort()一样,即使具有最高优化级别和NDEBUG模式(Linux Clang编译器),nth_element函数在探查器中也是可见的,因此比较是内联的而不是函数调用本身。好吧,更多的前提:不是nth_element()但是std :: __ introselect()是可见的。
由于数据的大小很小,我尝试使用二次排序函数PIKSORT,当数据大小小于20个元素时,它通常比调用std :: sort更快,可能是因为函数将是内联的
template <class CONTAINER>
inline void piksort(CONTAINER& arr) // indeed this is "insertion sort"
{
typename CONTAINER::value_type a;
const int n = (int)arr.size();
for (int j = 1; j<n; ++j) {
a = arr[j];
int i = j;
while (i > 0 && a < arr[i - 1]) {
arr[i] = arr[i - 1];
i--;
}
arr[i] = a;
}
}
然而,这比在这种情况下使用nth_element慢。
另外,使用统计方法不合适,Something faster than std::nth_element
最后,由于值在0到约20000的范围内,因此直方图方法看起来不合适。
我的问题:有没有人知道一个简单的解决方案?我想我可能不是唯一一个经常调用std :: sort或nth_element的人。
答案 0 :(得分:13)
您提到数组的大小始终为23
。此外,使用的类型是unsigned short
。在这种情况下,您可能会尝试使用大小为23
的排序网络;由于您的类型为unsigned short
,因此使用排序网络对整个数组进行排序可能比使用std::nth_element
对其进行部分排序更快。下面是一个非常直接的C ++ 14实现的大小为23
且具有118
比较交换单元的排序网络,如Using Symmetry and Evolutionary Search to Minimize Sorting Networks所述:
template<typename RandomIt, typename Compare = std::less<>>
void network_sort23(RandomIt first, Compare compare={})
{
swap_if(first[1u], first[20u], compare);
swap_if(first[2u], first[21u], compare);
swap_if(first[5u], first[13u], compare);
swap_if(first[9u], first[17u], compare);
swap_if(first[0u], first[7u], compare);
swap_if(first[15u], first[22u], compare);
swap_if(first[4u], first[11u], compare);
swap_if(first[6u], first[12u], compare);
swap_if(first[10u], first[16u], compare);
swap_if(first[8u], first[18u], compare);
swap_if(first[14u], first[19u], compare);
swap_if(first[3u], first[8u], compare);
swap_if(first[4u], first[14u], compare);
swap_if(first[11u], first[18u], compare);
swap_if(first[2u], first[6u], compare);
swap_if(first[16u], first[20u], compare);
swap_if(first[0u], first[9u], compare);
swap_if(first[13u], first[22u], compare);
swap_if(first[5u], first[15u], compare);
swap_if(first[7u], first[17u], compare);
swap_if(first[1u], first[10u], compare);
swap_if(first[12u], first[21u], compare);
swap_if(first[8u], first[19u], compare);
swap_if(first[17u], first[22u], compare);
swap_if(first[0u], first[5u], compare);
swap_if(first[20u], first[21u], compare);
swap_if(first[1u], first[2u], compare);
swap_if(first[18u], first[19u], compare);
swap_if(first[3u], first[4u], compare);
swap_if(first[21u], first[22u], compare);
swap_if(first[0u], first[1u], compare);
swap_if(first[19u], first[22u], compare);
swap_if(first[0u], first[3u], compare);
swap_if(first[12u], first[13u], compare);
swap_if(first[9u], first[10u], compare);
swap_if(first[6u], first[15u], compare);
swap_if(first[7u], first[16u], compare);
swap_if(first[8u], first[11u], compare);
swap_if(first[11u], first[14u], compare);
swap_if(first[4u], first[11u], compare);
swap_if(first[6u], first[8u], compare);
swap_if(first[14u], first[16u], compare);
swap_if(first[17u], first[20u], compare);
swap_if(first[2u], first[5u], compare);
swap_if(first[9u], first[12u], compare);
swap_if(first[10u], first[13u], compare);
swap_if(first[15u], first[18u], compare);
swap_if(first[10u], first[11u], compare);
swap_if(first[4u], first[7u], compare);
swap_if(first[20u], first[21u], compare);
swap_if(first[1u], first[2u], compare);
swap_if(first[7u], first[15u], compare);
swap_if(first[3u], first[9u], compare);
swap_if(first[13u], first[19u], compare);
swap_if(first[16u], first[18u], compare);
swap_if(first[8u], first[14u], compare);
swap_if(first[4u], first[6u], compare);
swap_if(first[18u], first[21u], compare);
swap_if(first[1u], first[4u], compare);
swap_if(first[19u], first[21u], compare);
swap_if(first[1u], first[3u], compare);
swap_if(first[9u], first[10u], compare);
swap_if(first[11u], first[13u], compare);
swap_if(first[2u], first[6u], compare);
swap_if(first[16u], first[20u], compare);
swap_if(first[4u], first[9u], compare);
swap_if(first[13u], first[18u], compare);
swap_if(first[19u], first[20u], compare);
swap_if(first[2u], first[3u], compare);
swap_if(first[18u], first[20u], compare);
swap_if(first[2u], first[4u], compare);
swap_if(first[5u], first[17u], compare);
swap_if(first[12u], first[14u], compare);
swap_if(first[8u], first[12u], compare);
swap_if(first[5u], first[7u], compare);
swap_if(first[15u], first[17u], compare);
swap_if(first[5u], first[8u], compare);
swap_if(first[14u], first[17u], compare);
swap_if(first[3u], first[5u], compare);
swap_if(first[17u], first[19u], compare);
swap_if(first[3u], first[4u], compare);
swap_if(first[18u], first[19u], compare);
swap_if(first[6u], first[10u], compare);
swap_if(first[11u], first[16u], compare);
swap_if(first[13u], first[16u], compare);
swap_if(first[6u], first[9u], compare);
swap_if(first[16u], first[17u], compare);
swap_if(first[5u], first[6u], compare);
swap_if(first[4u], first[5u], compare);
swap_if(first[7u], first[9u], compare);
swap_if(first[17u], first[18u], compare);
swap_if(first[12u], first[15u], compare);
swap_if(first[14u], first[15u], compare);
swap_if(first[8u], first[12u], compare);
swap_if(first[7u], first[8u], compare);
swap_if(first[13u], first[15u], compare);
swap_if(first[15u], first[17u], compare);
swap_if(first[5u], first[7u], compare);
swap_if(first[9u], first[10u], compare);
swap_if(first[10u], first[14u], compare);
swap_if(first[6u], first[11u], compare);
swap_if(first[14u], first[16u], compare);
swap_if(first[15u], first[16u], compare);
swap_if(first[6u], first[7u], compare);
swap_if(first[10u], first[11u], compare);
swap_if(first[9u], first[12u], compare);
swap_if(first[11u], first[13u], compare);
swap_if(first[13u], first[14u], compare);
swap_if(first[8u], first[9u], compare);
swap_if(first[7u], first[8u], compare);
swap_if(first[14u], first[15u], compare);
swap_if(first[9u], first[10u], compare);
swap_if(first[8u], first[9u], compare);
swap_if(first[12u], first[14u], compare);
swap_if(first[11u], first[12u], compare);
swap_if(first[12u], first[13u], compare);
swap_if(first[10u], first[11u], compare);
swap_if(first[11u], first[12u], compare);
}
swap_if
效用函数将两个参数x
和y
与谓词compare
进行比较,并在compare(y, x)
时对其进行交换。我的示例使用了通用的swap_if
函数,但如果您知道将unsigned short
值与operator<
进行比较,则可以使用优化版本(如果您可能不需要这样的函数)编译器识别并优化了比较交换,但不幸的是,并非所有编译器都这样做 - 我使用g ++ 5.2和-O3
,我仍然需要以下函数来提高性能):
void swap_if(unsigned short& x, unsigned short& y)
{
unsigned short dx = x;
unsigned short dy = y;
unsigned short tmp = x = std::min(dx, dy);
y ^= dx ^ tmp;
}
现在,为了确保它确实更快,我决定在需要时将时间std::nth_element
分别只对前10个元素进行部分排序与对整个23个元素进行排序排序网络(使用不同的混洗阵列1000000次)。这是我得到的:
std::nth_element 1158ms
network_sort23 487ms
那就是说,我的电脑已经运行了一段时间并且有点慢,但性能差异很大。我相信当我重新启动计算机时,这种差异将保持不变。我可以稍后再试,让你知道。
关于如何生成这些时间,我使用了this benchmark中cpp-sort library的修改版本。原始排序网络和swap_if
函数也来自那里,因此您可以确保它们已经过多次测试:)
编辑:现在结果是我重新启动了计算机。 network_sort23
版本仍然比std::nth_element
快两倍:
std::nth_element 369ms
network_sort23 154ms
EDIT²:如果您只需要中位数,那么您可以简单地删除不需要的比较交换单元来计算将在第11位的最终值。随后得到的大小为23的中位数查找网络使用与前一个不同的23号排序网络,结果略好一些:
swap_if(first[0u], first[1u], compare);
swap_if(first[2u], first[3u], compare);
swap_if(first[4u], first[5u], compare);
swap_if(first[6u], first[7u], compare);
swap_if(first[8u], first[9u], compare);
swap_if(first[10u], first[11u], compare);
swap_if(first[1u], first[3u], compare);
swap_if(first[5u], first[7u], compare);
swap_if(first[9u], first[11u], compare);
swap_if(first[0u], first[2u], compare);
swap_if(first[4u], first[6u], compare);
swap_if(first[8u], first[10u], compare);
swap_if(first[1u], first[2u], compare);
swap_if(first[5u], first[6u], compare);
swap_if(first[9u], first[10u], compare);
swap_if(first[1u], first[5u], compare);
swap_if(first[6u], first[10u], compare);
swap_if(first[5u], first[9u], compare);
swap_if(first[2u], first[6u], compare);
swap_if(first[1u], first[5u], compare);
swap_if(first[6u], first[10u], compare);
swap_if(first[0u], first[4u], compare);
swap_if(first[7u], first[11u], compare);
swap_if(first[3u], first[7u], compare);
swap_if(first[4u], first[8u], compare);
swap_if(first[0u], first[4u], compare);
swap_if(first[7u], first[11u], compare);
swap_if(first[1u], first[4u], compare);
swap_if(first[7u], first[10u], compare);
swap_if(first[3u], first[8u], compare);
swap_if(first[2u], first[3u], compare);
swap_if(first[8u], first[9u], compare);
swap_if(first[2u], first[4u], compare);
swap_if(first[7u], first[9u], compare);
swap_if(first[3u], first[5u], compare);
swap_if(first[6u], first[8u], compare);
swap_if(first[3u], first[4u], compare);
swap_if(first[5u], first[6u], compare);
swap_if(first[7u], first[8u], compare);
swap_if(first[12u], first[13u], compare);
swap_if(first[14u], first[15u], compare);
swap_if(first[16u], first[17u], compare);
swap_if(first[18u], first[19u], compare);
swap_if(first[20u], first[21u], compare);
swap_if(first[13u], first[15u], compare);
swap_if(first[17u], first[19u], compare);
swap_if(first[12u], first[14u], compare);
swap_if(first[16u], first[18u], compare);
swap_if(first[20u], first[22u], compare);
swap_if(first[13u], first[14u], compare);
swap_if(first[17u], first[18u], compare);
swap_if(first[21u], first[22u], compare);
swap_if(first[13u], first[17u], compare);
swap_if(first[18u], first[22u], compare);
swap_if(first[17u], first[21u], compare);
swap_if(first[14u], first[18u], compare);
swap_if(first[13u], first[17u], compare);
swap_if(first[18u], first[22u], compare);
swap_if(first[12u], first[16u], compare);
swap_if(first[15u], first[19u], compare);
swap_if(first[16u], first[20u], compare);
swap_if(first[12u], first[16u], compare);
swap_if(first[13u], first[16u], compare);
swap_if(first[19u], first[22u], compare);
swap_if(first[15u], first[20u], compare);
swap_if(first[14u], first[15u], compare);
swap_if(first[20u], first[21u], compare);
swap_if(first[14u], first[16u], compare);
swap_if(first[19u], first[21u], compare);
swap_if(first[15u], first[17u], compare);
swap_if(first[18u], first[20u], compare);
swap_if(first[15u], first[16u], compare);
swap_if(first[17u], first[18u], compare);
swap_if(first[19u], first[20u], compare);
swap_if(first[0u], first[12u], compare);
swap_if(first[2u], first[14u], compare);
swap_if(first[4u], first[16u], compare);
swap_if(first[6u], first[18u], compare);
swap_if(first[8u], first[20u], compare);
swap_if(first[10u], first[22u], compare);
swap_if(first[2u], first[12u], compare);
swap_if(first[10u], first[20u], compare);
swap_if(first[4u], first[12u], compare);
swap_if(first[6u], first[14u], compare);
swap_if(first[8u], first[16u], compare);
swap_if(first[10u], first[18u], compare);
swap_if(first[8u], first[12u], compare);
swap_if(first[10u], first[14u], compare);
swap_if(first[10u], first[12u], compare);
swap_if(first[1u], first[13u], compare);
swap_if(first[3u], first[15u], compare);
swap_if(first[5u], first[17u], compare);
swap_if(first[7u], first[19u], compare);
swap_if(first[9u], first[21u], compare);
swap_if(first[3u], first[13u], compare);
swap_if(first[11u], first[21u], compare);
swap_if(first[5u], first[13u], compare);
swap_if(first[7u], first[15u], compare);
swap_if(first[9u], first[17u], compare);
swap_if(first[11u], first[19u], compare);
swap_if(first[9u], first[13u], compare);
swap_if(first[11u], first[15u], compare);
swap_if(first[11u], first[13u], compare);
swap_if(first[11u], first[12u], compare);
可能有更智能的方法来生成中位数发现网络,但我认为没有对该主题进行过广泛的研究。因此,它可能是您现在可以使用的最佳方法。结果并不令人敬畏,但它仍然使用104个比较交换单元而不是118个。
答案 1 :(得分:4)
查看MSVC2013中std::nth_element
的源代码,似乎 N <= 32 的情况通过插入排序解决。这意味着STL实现者意识到,随着这些大小的渐近渐近,随机分区的运行会更慢。
提高性能的方法之一是优化排序算法。 @Morwenn's answer展示了如何使用排序网络对23个元素进行排序,这已知是对小型恒定大小数组进行排序的最快方法之一。 我将研究另一种方法,即在没有排序算法的情况下计算中值。实际上,我根本不会置换输入数组。
由于我们讨论的是小型数组,我们需要以最简单的方式实现一些 O(N ^ 2)算法。理想情况下,它应该根本没有分支,或者只有完全可预测的分支。此外,算法的简单结构可以让我们对其进行矢量化,从而进一步提高其性能。
我已决定使用计数方法,该方法使用here来加速小型线性搜索。首先,假设所有元素都不同。选择数组的任何元素:小于它的元素数定义它在排序数组中的位置。我们可以遍历所有元素,并为每个元素计算小于它的元素数量。如果排序的索引具有所需的值,我们可以停止算法。
不幸的是,一般情况下可能存在相同的元素。我们必须使我们的算法显着更慢,更复杂,以处理它们。我们可以计算可能的排序索引的间隔,而不是计算元素的唯一排序索引。对于任何元素,只需计算小于它的元素数( L )和等于它的元素数( E ),然后排序索引适合范围 [L,L + R] 。如果此间隔包含所需的排序索引(即 N / 2 ),那么我们可以停止算法并返回所考虑的元素。
for (size_t i = 0; i < n; i++) {
auto x = arr[i];
//count number of "less" and "equal" elements
int cntLess = 0, cntEq = 0;
for (size_t j = 0; j < n; j++) {
cntLess += arr[j] < x;
cntEq += arr[j] == x;
}
//fast range checking from here: https://stackoverflow.com/a/17095534/556899
if ((unsigned int)(idx - cntLess) < cntEq)
return x;
}
构造的算法只有一个分支,这是可以预测的:它在所有情况下都失败了,除了我们停止算法的唯一情况。该算法易于使用每个SSE寄存器的8个元素进行矢量化。由于我们必须在最后一个元素之后访问一些元素,我将假设输入数组填充 max = 2 ^ 15-1 值,最多24或32个元素。
第一种方法是通过j
向内循环。在这种情况下,内循环只执行3次,但必须在完成后进行两次8宽减少。他们比内圈本身吃的时间更长。结果,这种矢量化效率不高。
第二种方法是通过i
向外循环。在这种情况下,我们一次处理8个元素x = arr[i]
。对于每个包,我们将它与内循环中的每个元素arr[j]
进行比较。在内循环之后,我们对整个8个元素组进行矢量化范围检查。如果它们中的任何一个成功,我们使用简单的标量代码确定确切的数字(无论如何都会占用很少的时间)。
__m128i idxV = _mm_set1_epi16(idx);
for (size_t i = 0; i < n; i += 8) {
//load pack of 8 elements
auto xx = _mm_loadu_si128((__m128i*)&arr[i]);
//count number of less/equal elements for each element in the pack
__m128i cntLess = _mm_setzero_si128();
__m128i cntEq = _mm_setzero_si128();
for (size_t j = 0; j < n; j++) {
__m128i vAll = _mm_set1_epi16(arr[j]);
cntLess = _mm_sub_epi16(cntLess, _mm_cmplt_epi16(vAll, xx));
cntEq = _mm_sub_epi16(cntEq, _mm_cmpeq_epi16(vAll, xx));
}
//perform range check for 8 elements at once
__m128i mask = _mm_andnot_si128(_mm_cmplt_epi16(idxV, cntLess), _mm_cmplt_epi16(idxV, _mm_add_epi16(cntLess, cntEq)));
if (int bm = _mm_movemask_epi8(mask)) {
//range check succeeds for one of the elements, find and return it
for (int t = 0; t < 8; t++)
if (bm & (1 << (2*t)))
return arr[i + t];
}
}
这里我们看到最内层循环中的_mm_set1_epi16
内在。海湾合作委员会似乎有一些性能问题。无论如何,它是在每个最里面的迭代上吃的时间,如果我们在最里面的循环中同时处理8个元素,这可以减少。在这种情况下,我们可以执行一个向量化加载和14个解包指令,以获得8个元素的vAll
。此外,我们必须为循环体中的八个元素编写比较和计数代码,因此它也可以作为8x展开。生成的代码是最快的代码,可以在下面找到它的链接。
我已经在Ivy Bridge 3.4 Ghz处理器上对各种解决方案进行了基准测试。您可以在下面看到 2 ^ 23~ = 8M 次呼叫的总计算时间(以秒为单位)(第一个数字)。第二个数字是结果的校验和。
MSVC 2013 x64( / O2 )的结果:
memcpy only: 0.020
std::nth_element: 2.110 (1186136064)
network sort: 0.630 (1186136064) //solution by @Morwenn (I had to change swap_if)
trivial count: 2.266 (1186136064) //scalar algorithm (presented above)
vectorized count: 0.692 (1186136064) //vectorization by j
vectorized count (T): 0.602 (1186136064) //vectorization by i (presented above)
vectorized count (both): 0.450 (1186136064) //vectorization by i and j
MinGW GCC 4.8.3 x64( -O3 -msse4 )的结果:
memcpy only: 0.016
std::nth_element: 1.981 (1095237632)
network sort: 0.531 (1095237632) //original swap_if used
trivial count: 1.482 (1095237632)
vectorized count: 0.655 (1095237632)
vectorized count (T): 2.668 (1095237632) //GCC generates some crap
vectorized count (both): 0.374 (1095237632)
如您所见,针对23个16位元素的建议矢量化算法比基于排序的方法快一点(BTW,在较旧的CPU上,我只看到5%的时间差)。 如果您可以保证所有元素都不同,您可以简化算法,使其更快。
所有算法的完整代码均可用here,包括所有测试代码。
答案 2 :(得分:2)
我发现这个问题很有趣,所以我尝试了所有我能想到的算法。
结果如下:
testing 100000 repetitions
variant 0, no-op (for overhead measure)
5 ms
variant 1, vector + nth_element
205 ms
variant 2, multiset + advance
745 ms
variant 2b, set (not fully conformant)
787 ms
variant 3, list + lower_bound
589 ms
variant 3b, list + block-allocator
269 ms
variant 4, avl-tree + insert_sorted
645 ms
variant 4b, avl-tree + prune
682 ms
variant 5, histogram
1429 ms
我想我们可以得出结论,你已经在哪里使用最快的
算法。男孩,我错了。但是,如果你能接受一个近似的答案,
可能有更快的方法,例如median of medians
如果您有兴趣,来源是here。