非常频繁地调用std :: nth_element()函数

时间:2015-10-23 17:11:27

标签: c++ performance sorting inline nth-element

我没有在任何地方找到这个特定主题......

我在23个整数的std :: vector中的不同数据上调用nth_element()算法,每秒大约400,000次,更精确的“无符号短”值。

我想提高计算速度,这个特定的调用需要很大一部分CPU时间。 现在我注意到,与std :: sort()一样,即使具有最高优化级别和NDEBUG模式(Linux Clang编译器),nth_element函数在探查器中也是可见的,因此比较是内联的而不是函数调用本身。好吧,更多的前提:不是nth_element()但是std :: __ introselect()是可见的。

由于数据的大小很小,我尝试使用二次排序函数PIKSORT,当数据大小小于20个元素时,它通常比调用std :: sort更快,可能是因为函数将是内联的

template <class CONTAINER>
inline void piksort(CONTAINER& arr)  // indeed this is "insertion sort"
{
    typename CONTAINER::value_type a;

    const int n = (int)arr.size();
    for (int j = 1; j<n; ++j) {
        a = arr[j];
        int i = j;
        while (i > 0 && a < arr[i - 1]) {
            arr[i] = arr[i - 1];
            i--;
        }
        arr[i] = a;
    }
}

然而,这比在这种情况下使用nth_element慢。

另外,使用统计方法不合适,Something faster than std::nth_element

最后,由于值在0到约20000的范围内,因此直方图方法看起来不合适。

我的问题:有没有人知道一个简单的解决方案?我想我可能不是唯一一个经常调用std :: sort或nth_element的人。

3 个答案:

答案 0 :(得分:13)

您提到数组的大小始终为23。此外,使用的类型是unsigned short。在这种情况下,您可能会尝试使用大小为23的排序网络;由于您的类型为unsigned short,因此使用排序网络对整个数组进行排序可能比使用std::nth_element对其进行部分排序更快。下面是一个非常直接的C ++ 14实现的大小为23且具有118比较交换单元的排序网络,如Using Symmetry and Evolutionary Search to Minimize Sorting Networks所述:

template<typename RandomIt, typename Compare = std::less<>>
void network_sort23(RandomIt first, Compare compare={})
{
    swap_if(first[1u], first[20u], compare);
    swap_if(first[2u], first[21u], compare);
    swap_if(first[5u], first[13u], compare);
    swap_if(first[9u], first[17u], compare);
    swap_if(first[0u], first[7u], compare);
    swap_if(first[15u], first[22u], compare);
    swap_if(first[4u], first[11u], compare);
    swap_if(first[6u], first[12u], compare);
    swap_if(first[10u], first[16u], compare);
    swap_if(first[8u], first[18u], compare);
    swap_if(first[14u], first[19u], compare);
    swap_if(first[3u], first[8u], compare);
    swap_if(first[4u], first[14u], compare);
    swap_if(first[11u], first[18u], compare);
    swap_if(first[2u], first[6u], compare);
    swap_if(first[16u], first[20u], compare);
    swap_if(first[0u], first[9u], compare);
    swap_if(first[13u], first[22u], compare);
    swap_if(first[5u], first[15u], compare);
    swap_if(first[7u], first[17u], compare);
    swap_if(first[1u], first[10u], compare);
    swap_if(first[12u], first[21u], compare);
    swap_if(first[8u], first[19u], compare);
    swap_if(first[17u], first[22u], compare);
    swap_if(first[0u], first[5u], compare);
    swap_if(first[20u], first[21u], compare);
    swap_if(first[1u], first[2u], compare);
    swap_if(first[18u], first[19u], compare);
    swap_if(first[3u], first[4u], compare);
    swap_if(first[21u], first[22u], compare);
    swap_if(first[0u], first[1u], compare);
    swap_if(first[19u], first[22u], compare);
    swap_if(first[0u], first[3u], compare);
    swap_if(first[12u], first[13u], compare);
    swap_if(first[9u], first[10u], compare);
    swap_if(first[6u], first[15u], compare);
    swap_if(first[7u], first[16u], compare);
    swap_if(first[8u], first[11u], compare);
    swap_if(first[11u], first[14u], compare);
    swap_if(first[4u], first[11u], compare);
    swap_if(first[6u], first[8u], compare);
    swap_if(first[14u], first[16u], compare);
    swap_if(first[17u], first[20u], compare);
    swap_if(first[2u], first[5u], compare);
    swap_if(first[9u], first[12u], compare);
    swap_if(first[10u], first[13u], compare);
    swap_if(first[15u], first[18u], compare);
    swap_if(first[10u], first[11u], compare);
    swap_if(first[4u], first[7u], compare);
    swap_if(first[20u], first[21u], compare);
    swap_if(first[1u], first[2u], compare);
    swap_if(first[7u], first[15u], compare);
    swap_if(first[3u], first[9u], compare);
    swap_if(first[13u], first[19u], compare);
    swap_if(first[16u], first[18u], compare);
    swap_if(first[8u], first[14u], compare);
    swap_if(first[4u], first[6u], compare);
    swap_if(first[18u], first[21u], compare);
    swap_if(first[1u], first[4u], compare);
    swap_if(first[19u], first[21u], compare);
    swap_if(first[1u], first[3u], compare);
    swap_if(first[9u], first[10u], compare);
    swap_if(first[11u], first[13u], compare);
    swap_if(first[2u], first[6u], compare);
    swap_if(first[16u], first[20u], compare);
    swap_if(first[4u], first[9u], compare);
    swap_if(first[13u], first[18u], compare);
    swap_if(first[19u], first[20u], compare);
    swap_if(first[2u], first[3u], compare);
    swap_if(first[18u], first[20u], compare);
    swap_if(first[2u], first[4u], compare);
    swap_if(first[5u], first[17u], compare);
    swap_if(first[12u], first[14u], compare);
    swap_if(first[8u], first[12u], compare);
    swap_if(first[5u], first[7u], compare);
    swap_if(first[15u], first[17u], compare);
    swap_if(first[5u], first[8u], compare);
    swap_if(first[14u], first[17u], compare);
    swap_if(first[3u], first[5u], compare);
    swap_if(first[17u], first[19u], compare);
    swap_if(first[3u], first[4u], compare);
    swap_if(first[18u], first[19u], compare);
    swap_if(first[6u], first[10u], compare);
    swap_if(first[11u], first[16u], compare);
    swap_if(first[13u], first[16u], compare);
    swap_if(first[6u], first[9u], compare);
    swap_if(first[16u], first[17u], compare);
    swap_if(first[5u], first[6u], compare);
    swap_if(first[4u], first[5u], compare);
    swap_if(first[7u], first[9u], compare);
    swap_if(first[17u], first[18u], compare);
    swap_if(first[12u], first[15u], compare);
    swap_if(first[14u], first[15u], compare);
    swap_if(first[8u], first[12u], compare);
    swap_if(first[7u], first[8u], compare);
    swap_if(first[13u], first[15u], compare);
    swap_if(first[15u], first[17u], compare);
    swap_if(first[5u], first[7u], compare);
    swap_if(first[9u], first[10u], compare);
    swap_if(first[10u], first[14u], compare);
    swap_if(first[6u], first[11u], compare);
    swap_if(first[14u], first[16u], compare);
    swap_if(first[15u], first[16u], compare);
    swap_if(first[6u], first[7u], compare);
    swap_if(first[10u], first[11u], compare);
    swap_if(first[9u], first[12u], compare);
    swap_if(first[11u], first[13u], compare);
    swap_if(first[13u], first[14u], compare);
    swap_if(first[8u], first[9u], compare);
    swap_if(first[7u], first[8u], compare);
    swap_if(first[14u], first[15u], compare);
    swap_if(first[9u], first[10u], compare);
    swap_if(first[8u], first[9u], compare);
    swap_if(first[12u], first[14u], compare);
    swap_if(first[11u], first[12u], compare);
    swap_if(first[12u], first[13u], compare);
    swap_if(first[10u], first[11u], compare);
    swap_if(first[11u], first[12u], compare);
}

swap_if效用函数将两个参数xy与谓词compare进行比较,并在compare(y, x)时对其进行交换。我的示例使用了通用的swap_if函数,但如果您知道将unsigned short值与operator<进行比较,则可以使用优化版本(如果您可能不需要这样的函数)编译器识别并优化了比较交换,但不幸的是,并非所有编译器都这样做 - 我使用g ++ 5.2和-O3,我仍然需要以下函数来提高性能):

void swap_if(unsigned short& x, unsigned short& y)
{
    unsigned short dx = x;
    unsigned short dy = y;
    unsigned short tmp = x = std::min(dx, dy);
    y ^= dx ^ tmp;
}

现在,为了确保它确实更快,我决定在需要时将时间std::nth_element分别只对前10个元素进行部分排序与对整个23个元素进行排序排序网络(使用不同的混洗阵列1000000次)。这是我得到的:

std::nth_element    1158ms
network_sort23      487ms

那就是说,我的电脑已经运行了一段时间并且有点慢,但性能差异很大。我相信当我重新启动计算机时,这种差异将保持不变。我可以稍后再试,让你知道。

关于如何生成这些时间,我使用了this benchmarkcpp-sort library的修改版本。原始排序网络和swap_if函数也来自那里,因此您可以确保它们已经过多次测试:)

编辑:现在结果是我重新启动了计算机。 network_sort23版本仍然比std::nth_element快两倍:

std::nth_element    369ms
network_sort23      154ms

EDIT²:如果您只需要中位数,那么您可以简单地删除不需要的比较交换单元来计算将在第11位的最终值。随后得到的大小为23的中位数查找网络使用与前一个不同的23号排序网络,结果略好一些:

swap_if(first[0u], first[1u], compare);
swap_if(first[2u], first[3u], compare);
swap_if(first[4u], first[5u], compare);
swap_if(first[6u], first[7u], compare);
swap_if(first[8u], first[9u], compare);
swap_if(first[10u], first[11u], compare);
swap_if(first[1u], first[3u], compare);
swap_if(first[5u], first[7u], compare);
swap_if(first[9u], first[11u], compare);
swap_if(first[0u], first[2u], compare);
swap_if(first[4u], first[6u], compare);
swap_if(first[8u], first[10u], compare);
swap_if(first[1u], first[2u], compare);
swap_if(first[5u], first[6u], compare);
swap_if(first[9u], first[10u], compare);
swap_if(first[1u], first[5u], compare);
swap_if(first[6u], first[10u], compare);
swap_if(first[5u], first[9u], compare);
swap_if(first[2u], first[6u], compare);
swap_if(first[1u], first[5u], compare);
swap_if(first[6u], first[10u], compare);
swap_if(first[0u], first[4u], compare);
swap_if(first[7u], first[11u], compare);
swap_if(first[3u], first[7u], compare);
swap_if(first[4u], first[8u], compare);
swap_if(first[0u], first[4u], compare);
swap_if(first[7u], first[11u], compare);
swap_if(first[1u], first[4u], compare);
swap_if(first[7u], first[10u], compare);
swap_if(first[3u], first[8u], compare);
swap_if(first[2u], first[3u], compare);
swap_if(first[8u], first[9u], compare);
swap_if(first[2u], first[4u], compare);
swap_if(first[7u], first[9u], compare);
swap_if(first[3u], first[5u], compare);
swap_if(first[6u], first[8u], compare);
swap_if(first[3u], first[4u], compare);
swap_if(first[5u], first[6u], compare);
swap_if(first[7u], first[8u], compare);
swap_if(first[12u], first[13u], compare);
swap_if(first[14u], first[15u], compare);
swap_if(first[16u], first[17u], compare);
swap_if(first[18u], first[19u], compare);
swap_if(first[20u], first[21u], compare);
swap_if(first[13u], first[15u], compare);
swap_if(first[17u], first[19u], compare);
swap_if(first[12u], first[14u], compare);
swap_if(first[16u], first[18u], compare);
swap_if(first[20u], first[22u], compare);
swap_if(first[13u], first[14u], compare);
swap_if(first[17u], first[18u], compare);
swap_if(first[21u], first[22u], compare);
swap_if(first[13u], first[17u], compare);
swap_if(first[18u], first[22u], compare);
swap_if(first[17u], first[21u], compare);
swap_if(first[14u], first[18u], compare);
swap_if(first[13u], first[17u], compare);
swap_if(first[18u], first[22u], compare);
swap_if(first[12u], first[16u], compare);
swap_if(first[15u], first[19u], compare);
swap_if(first[16u], first[20u], compare);
swap_if(first[12u], first[16u], compare);
swap_if(first[13u], first[16u], compare);
swap_if(first[19u], first[22u], compare);
swap_if(first[15u], first[20u], compare);
swap_if(first[14u], first[15u], compare);
swap_if(first[20u], first[21u], compare);
swap_if(first[14u], first[16u], compare);
swap_if(first[19u], first[21u], compare);
swap_if(first[15u], first[17u], compare);
swap_if(first[18u], first[20u], compare);
swap_if(first[15u], first[16u], compare);
swap_if(first[17u], first[18u], compare);
swap_if(first[19u], first[20u], compare);
swap_if(first[0u], first[12u], compare);
swap_if(first[2u], first[14u], compare);
swap_if(first[4u], first[16u], compare);
swap_if(first[6u], first[18u], compare);
swap_if(first[8u], first[20u], compare);
swap_if(first[10u], first[22u], compare);
swap_if(first[2u], first[12u], compare);
swap_if(first[10u], first[20u], compare);
swap_if(first[4u], first[12u], compare);
swap_if(first[6u], first[14u], compare);
swap_if(first[8u], first[16u], compare);
swap_if(first[10u], first[18u], compare);
swap_if(first[8u], first[12u], compare);
swap_if(first[10u], first[14u], compare);
swap_if(first[10u], first[12u], compare);
swap_if(first[1u], first[13u], compare);
swap_if(first[3u], first[15u], compare);
swap_if(first[5u], first[17u], compare);
swap_if(first[7u], first[19u], compare);
swap_if(first[9u], first[21u], compare);
swap_if(first[3u], first[13u], compare);
swap_if(first[11u], first[21u], compare);
swap_if(first[5u], first[13u], compare);
swap_if(first[7u], first[15u], compare);
swap_if(first[9u], first[17u], compare);
swap_if(first[11u], first[19u], compare);
swap_if(first[9u], first[13u], compare);
swap_if(first[11u], first[15u], compare);
swap_if(first[11u], first[13u], compare);
swap_if(first[11u], first[12u], compare);

可能有更智能的方法来生成中位数发现网络,但我认为没有对该主题进行过广泛的研究。因此,它可能是您现在可以使用的最佳方法。结果并不令人敬畏,但它仍然使用104个比较交换单元而不是118个。

答案 1 :(得分:4)

一般想法

查看MSVC2013中std::nth_element的源代码,似乎 N <= 32 的情况通过插入排序解决。这意味着STL实现者意识到,随着这些大小的渐近渐近,随机分区的运行会更慢。

提高性能的方法之一是优化排序算法。 @Morwenn's answer展示了如何使用排序网络对23个元素进行排序,这已知是对小型恒定大小数组进行排序的最快方法之一。 我将研究另一种方法,即在没有排序算法的情况下计算中值。实际上,我根本不会置换输入数组。

由于我们讨论的是小型数组,我们需要以最简单的方式实现一些 O(N ^ 2)算法。理想情况下,它应该根本没有分支,或者只有完全可预测的分支。此外,算法的简单结构可以让我们对其进行矢量化,从而进一步提高其性能。

算法

我已决定使用计数方法,该方法使用here来加速小型线性搜索。首先,假设所有元素都不同。选择数组的任何元素:小于它的元素数定义它在排序数组中的位置。我们可以遍历所有元素,并为每个元素计算小于它的元素数量。如果排序的索引具有所需的值,我们可以停止算法。

不幸的是,一般情况下可能存在相同的元素。我们必须使我们的算法显着更慢,更复杂,以处理它们。我们可以计算可能的排序索引的间隔,而不是计算元素的唯一排序索引。对于任何元素,只需计算小于它的元素数( L )和等于它的元素数( E ),然后排序索引适合范围 [L,L + R] 。如果此间隔包含所需的排序索引(即 N / 2 ),那么我们可以停止算法并返回所考虑的元素。

for (size_t i = 0; i < n; i++) {
    auto x = arr[i];
    //count number of "less" and "equal" elements
    int cntLess = 0, cntEq = 0;
    for (size_t j = 0; j < n; j++) {
        cntLess += arr[j] < x;
        cntEq += arr[j] == x;
    }
    //fast range checking from here: https://stackoverflow.com/a/17095534/556899
    if ((unsigned int)(idx - cntLess) < cntEq)
        return x;
}

矢量

构造的算法只有一个分支,这是可以预测的:它在所有情况下都失败了,除了我们停止算法的唯一情况。该算法易于使用每个SSE寄存器的8个元素进行矢量化。由于我们必须在最后一个元素之后访问一些元素,我将假设输入数组填充 max = 2 ^ 15-1 值,最多24或32个元素。

第一种方法是通过j向内循环。在这种情况下,内循环只执行3次,但必须在完成后进行两次8宽减少。他们比内圈本身吃的时间更长。结果,这种矢量化效率不高。

第二种方法是通过i向外循环。在这种情况下,我们一次处理8个元素x = arr[i]。对于每个包,我们将它与内循环中的每个元素arr[j]进行比较。在内循环之后,我们对整个8个元素组进行矢量化范围检查。如果它们中的任何一个成功,我们使用简单的标量代码确定确切的数字(无论如何都会占用很少的时间)。

__m128i idxV = _mm_set1_epi16(idx);
for (size_t i = 0; i < n; i += 8) {
    //load pack of 8 elements
    auto xx = _mm_loadu_si128((__m128i*)&arr[i]);
    //count number of less/equal elements for each element in the pack
    __m128i cntLess = _mm_setzero_si128();
    __m128i cntEq = _mm_setzero_si128();
    for (size_t j = 0; j < n; j++) {
        __m128i vAll = _mm_set1_epi16(arr[j]);
        cntLess = _mm_sub_epi16(cntLess, _mm_cmplt_epi16(vAll, xx));
        cntEq = _mm_sub_epi16(cntEq, _mm_cmpeq_epi16(vAll, xx));
    }
    //perform range check for 8 elements at once
    __m128i mask = _mm_andnot_si128(_mm_cmplt_epi16(idxV, cntLess), _mm_cmplt_epi16(idxV, _mm_add_epi16(cntLess, cntEq)));
    if (int bm = _mm_movemask_epi8(mask)) {
        //range check succeeds for one of the elements, find and return it 
        for (int t = 0; t < 8; t++)
            if (bm & (1 << (2*t)))
                return arr[i + t];
    }
}

这里我们看到最内层循环中的_mm_set1_epi16内在。海湾合作委员会似乎有一些性能问题。无论如何,它是在每个最里面的迭代上吃的时间,如果我们在最里面的循环中同时处理8个元素,这可以减少。在这种情况下,我们可以执行一个向量化加载和14个解包指令,以获得8个元素的vAll。此外,我们必须为循环体中的八个元素编写比较和计数代码,因此它也可以作为8x展开。生成的代码是最快的代码,可以在下面找到它的链接。

比较

我已经在Ivy Bridge 3.4 Ghz处理器上对各种解决方案进行了基准测试。您可以在下面看到 2 ^ 23~ = 8M 次呼叫的总计算时间(以秒为单位)(第一个数字)。第二个数字是结果的校验和。

MSVC 2013 x64( / O2 )的结果:

memcpy only: 0.020
std::nth_element: 2.110 (1186136064)
network sort: 0.630 (1186136064)              //solution by @Morwenn (I had to change swap_if)
trivial count: 2.266 (1186136064)             //scalar algorithm (presented above)
vectorized count: 0.692 (1186136064)          //vectorization by j
vectorized count (T): 0.602 (1186136064)      //vectorization by i (presented above)
vectorized count (both): 0.450 (1186136064)   //vectorization by i and j

MinGW GCC 4.8.3 x64( -O3 -msse4 )的结果:

memcpy only: 0.016
std::nth_element: 1.981 (1095237632)
network sort: 0.531 (1095237632)              //original swap_if used
trivial count: 1.482 (1095237632)
vectorized count: 0.655 (1095237632)
vectorized count (T): 2.668 (1095237632)      //GCC generates some crap
vectorized count (both): 0.374 (1095237632)

如您所见,针对23个16位元素的建议矢量化算法比基于排序的方法快一点(BTW,在较旧的CPU上,我只看到5%的时间差)。 如果您可以保证所有元素都不同,您可以简化算法,使其更快。

所有算法的完整代码均可用here,包括所有测试代码。

答案 2 :(得分:2)

我发现这个问题很有趣,所以我尝试了所有我能想到的算法。
结果如下:

testing 100000 repetitions
variant 0, no-op (for overhead measure)
5 ms
variant 1, vector + nth_element
205 ms
variant 2, multiset + advance
745 ms
variant 2b, set (not fully conformant)
787 ms
variant 3, list + lower_bound
589 ms
variant 3b, list + block-allocator
269 ms
variant 4, avl-tree + insert_sorted
645 ms
variant 4b, avl-tree + prune
682 ms
variant 5, histogram
1429 ms

我想我们可以得出结论,你已经在哪里使用最快的 算法。男孩,我错了。但是,如果你能接受一个近似的答案, 可能有更快的方法,例如median of medians 如果您有兴趣,来源是here