每个位列上的AVX2列填充计数算法分别

时间:2019-10-21 12:18:01

标签: c++ visual-c++ x86 simd avx2

对于我正在从事的项目,我需要计算翻录的PDF图像数据中每列的设置位数。

我正在尝试获取整个PDF作业(所有页面)中每一列的总设置位数。

数据一旦被翻录,就存储在MemoryMappedFile中,没有备份文件(在内存中)。

PDF页面尺寸为13952像素x 15125像素。可以通过将PDF(以像素为单位)的长度(高度)乘以以字节为单位的宽度来计算所得的翻录数据的总大小。翻录的数据为1 bit == 1 pixel。因此,翻录页面的大小(以字节为单位)为(13952 / 8) * 15125

请注意,宽度始终是64 bits的倍数。

在翻录后,我必须计算PDF每页(可能是几万页)中每一列的设置位。

我首先用一种基本的解决方法解决了这个问题,即仅遍历每个字节并计算设置的位数,然后将结果放入vector中。此后,我将算法缩减为以下所示的内容。我已经从大约350ms的执行时间变成了大约120ms。

static void count_dots( )
{
    using namespace diag;
    using namespace std::chrono;

    std::vector<std::size_t> dot_counts( 13952, 0 );
    uint64_t* ptr_dot_counts{ dot_counts.data( ) };

    std::vector<uint64_t> ripped_pdf_data( 3297250, 0xFFFFFFFFFFFFFFFFUL );
    const uint64_t* ptr_data{ ripped_pdf_data.data( ) };

    std::size_t line_count{ 0 };
    std::size_t counter{ ripped_pdf_data.size( ) };

    stopwatch sw;
    sw.start( );

    while( counter > 0 )
    {
        *ptr_dot_counts++ += ( ( *ptr_data >> 7 ) & 0x0100000000000000UL ) >> 56;
        *ptr_dot_counts++ += ( ( *ptr_data >> 7 ) & 0x0001000000000000UL ) >> 48;
        *ptr_dot_counts++ += ( ( *ptr_data >> 7 ) & 0x0000010000000000UL ) >> 40;
        *ptr_dot_counts++ += ( ( *ptr_data >> 7 ) & 0x0000000100000000UL ) >> 32;
        *ptr_dot_counts++ += ( ( *ptr_data >> 7 ) & 0x0000000001000000UL ) >> 24;
        *ptr_dot_counts++ += ( ( *ptr_data >> 7 ) & 0x0000000000010000UL ) >> 16;
        *ptr_dot_counts++ += ( ( *ptr_data >> 7 ) & 0x0000000000000100UL ) >> 8;
        *ptr_dot_counts++ += ( ( *ptr_data >> 7 ) & 0x0000000000000001UL ) >> 0;
        *ptr_dot_counts++ += ( ( *ptr_data >> 6 ) & 0x0100000000000000UL ) >> 56;
        *ptr_dot_counts++ += ( ( *ptr_data >> 6 ) & 0x0001000000000000UL ) >> 48;
        *ptr_dot_counts++ += ( ( *ptr_data >> 6 ) & 0x0000010000000000UL ) >> 40;
        *ptr_dot_counts++ += ( ( *ptr_data >> 6 ) & 0x0000000100000000UL ) >> 32;
        *ptr_dot_counts++ += ( ( *ptr_data >> 6 ) & 0x0000000001000000UL ) >> 24;
        *ptr_dot_counts++ += ( ( *ptr_data >> 6 ) & 0x0000000000010000UL ) >> 16;
        *ptr_dot_counts++ += ( ( *ptr_data >> 6 ) & 0x0000000000000100UL ) >> 8;
        *ptr_dot_counts++ += ( ( *ptr_data >> 6 ) & 0x0000000000000001UL ) >> 0;
        *ptr_dot_counts++ += ( ( *ptr_data >> 5 ) & 0x0100000000000000UL ) >> 56;
        *ptr_dot_counts++ += ( ( *ptr_data >> 5 ) & 0x0001000000000000UL ) >> 48;
        *ptr_dot_counts++ += ( ( *ptr_data >> 5 ) & 0x0000010000000000UL ) >> 40;
        *ptr_dot_counts++ += ( ( *ptr_data >> 5 ) & 0x0000000100000000UL ) >> 32;
        *ptr_dot_counts++ += ( ( *ptr_data >> 5 ) & 0x0000000001000000UL ) >> 24;
        *ptr_dot_counts++ += ( ( *ptr_data >> 5 ) & 0x0000000000010000UL ) >> 16;
        *ptr_dot_counts++ += ( ( *ptr_data >> 5 ) & 0x0000000000000100UL ) >> 8;
        *ptr_dot_counts++ += ( ( *ptr_data >> 5 ) & 0x0000000000000001UL ) >> 0;
        *ptr_dot_counts++ += ( ( *ptr_data >> 4 ) & 0x0100000000000000UL ) >> 56;
        *ptr_dot_counts++ += ( ( *ptr_data >> 4 ) & 0x0001000000000000UL ) >> 48;
        *ptr_dot_counts++ += ( ( *ptr_data >> 4 ) & 0x0000010000000000UL ) >> 40;
        *ptr_dot_counts++ += ( ( *ptr_data >> 4 ) & 0x0000000100000000UL ) >> 32;
        *ptr_dot_counts++ += ( ( *ptr_data >> 4 ) & 0x0000000001000000UL ) >> 24;
        *ptr_dot_counts++ += ( ( *ptr_data >> 4 ) & 0x0000000000010000UL ) >> 16;
        *ptr_dot_counts++ += ( ( *ptr_data >> 4 ) & 0x0000000000000100UL ) >> 8;
        *ptr_dot_counts++ += ( ( *ptr_data >> 4 ) & 0x0000000000000001UL ) >> 0;
        *ptr_dot_counts++ += ( ( *ptr_data >> 3 ) & 0x0100000000000000UL ) >> 56;
        *ptr_dot_counts++ += ( ( *ptr_data >> 3 ) & 0x0001000000000000UL ) >> 48;
        *ptr_dot_counts++ += ( ( *ptr_data >> 3 ) & 0x0000010000000000UL ) >> 40;
        *ptr_dot_counts++ += ( ( *ptr_data >> 3 ) & 0x0000000100000000UL ) >> 32;
        *ptr_dot_counts++ += ( ( *ptr_data >> 3 ) & 0x0000000001000000UL ) >> 24;
        *ptr_dot_counts++ += ( ( *ptr_data >> 3 ) & 0x0000000000010000UL ) >> 16;
        *ptr_dot_counts++ += ( ( *ptr_data >> 3 ) & 0x0000000000000100UL ) >> 8;
        *ptr_dot_counts++ += ( ( *ptr_data >> 3 ) & 0x0000000000000001UL ) >> 0;
        *ptr_dot_counts++ += ( ( *ptr_data >> 2 ) & 0x0100000000000000UL ) >> 56;
        *ptr_dot_counts++ += ( ( *ptr_data >> 2 ) & 0x0001000000000000UL ) >> 48;
        *ptr_dot_counts++ += ( ( *ptr_data >> 2 ) & 0x0000010000000000UL ) >> 40;
        *ptr_dot_counts++ += ( ( *ptr_data >> 2 ) & 0x0000000100000000UL ) >> 32;
        *ptr_dot_counts++ += ( ( *ptr_data >> 2 ) & 0x0000000001000000UL ) >> 24;
        *ptr_dot_counts++ += ( ( *ptr_data >> 2 ) & 0x0000000000010000UL ) >> 16;
        *ptr_dot_counts++ += ( ( *ptr_data >> 2 ) & 0x0000000000000100UL ) >> 8;
        *ptr_dot_counts++ += ( ( *ptr_data >> 2 ) & 0x0000000000000001UL ) >> 0;
        *ptr_dot_counts++ += ( ( *ptr_data >> 1 ) & 0x0100000000000000UL ) >> 56;
        *ptr_dot_counts++ += ( ( *ptr_data >> 1 ) & 0x0001000000000000UL ) >> 48;
        *ptr_dot_counts++ += ( ( *ptr_data >> 1 ) & 0x0000010000000000UL ) >> 40;
        *ptr_dot_counts++ += ( ( *ptr_data >> 1 ) & 0x0000000100000000UL ) >> 32;
        *ptr_dot_counts++ += ( ( *ptr_data >> 1 ) & 0x0000000001000000UL ) >> 24;
        *ptr_dot_counts++ += ( ( *ptr_data >> 1 ) & 0x0000000000010000UL ) >> 16;
        *ptr_dot_counts++ += ( ( *ptr_data >> 1 ) & 0x0000000000000100UL ) >> 8;
        *ptr_dot_counts++ += ( ( *ptr_data >> 1 ) & 0x0000000000000001UL ) >> 0;
        *ptr_dot_counts++ += ( ( *ptr_data >> 0 ) & 0x0100000000000000UL ) >> 56;
        *ptr_dot_counts++ += ( ( *ptr_data >> 0 ) & 0x0001000000000000UL ) >> 48;
        *ptr_dot_counts++ += ( ( *ptr_data >> 0 ) & 0x0000010000000000UL ) >> 40;
        *ptr_dot_counts++ += ( ( *ptr_data >> 0 ) & 0x0000000100000000UL ) >> 32;
        *ptr_dot_counts++ += ( ( *ptr_data >> 0 ) & 0x0000000001000000UL ) >> 24;
        *ptr_dot_counts++ += ( ( *ptr_data >> 0 ) & 0x0000000000010000UL ) >> 16;
        *ptr_dot_counts++ += ( ( *ptr_data >> 0 ) & 0x0000000000000100UL ) >> 8;
        *ptr_dot_counts++ += ( ( *ptr_data >> 0 ) & 0x0000000000000001UL ) >> 0;

        ++ptr_data;
        --counter;
        if( ++line_count >= 218 )
        {
            ptr_dot_counts = dot_counts.data( );
            line_count = 0;
        }
    }   

    sw.stop( );
    std::cout << sw.elapsed<milliseconds>( ) << "ms\n";
}

不幸的是,这仍然会增加很多额外的处理时间,这是无法接受的。

上面的代码很丑陋,不会赢得任何选美比赛,但它有助于减少执行时间。 自从我写原始版本以来,我已经完成了以下工作:

  • 使用pointers代替indexers
  • uint64而不是uint8的块形式处理数据
  • 手动展开for循环以遍历bit的每个byte中的每个uint64
  • 在掩蔽后使用最终bit shift而不是__popcnt64来计数集合bit

对于此测试,我正在生成伪造的数据,其中每个bit都设置为1。测试完成后,dot_counts vector应该为每个15125包含element

我希望这里的一些人可以帮助我使算法的平均执行时间低于100毫秒。 我不在乎这里的可移植性。

  • 目标计算机的CPU:Xeon E5-2680 v4 - Intel
  • 编译器:MSVC++ 14.23
  • 操作系统:Windows 10
  • C ++版本:C++17
  • 编译器标志:/O2 /arch:AVX2

大约8年前,有人问了一个非常类似的问题: How to quickly count bits into separate bins in a series of ints on Sandy Bridge?

(编者注:也许您错过了Count each bit-position separately over many 64-bit bitmasks, with AVX but not AVX2,它提供了一些最新的更快的答案,至少是因为在连续的内存中沿着一列而不是沿着一行前进。也许您可以走1或2条高速缓存行向下一列,这样您就可以使SIMD寄存器中的计数器保持高温。)

当我将到目前为止所获得的答案与公认的答案进行比较时,我就非常接近了。我已经在处理uint64而不是uint8的块了。我只是想知道是否还有更多可以做的事情,无论是使用内在函数,汇编还是简单的事情,例如更改我正在使用的数据结构。

1 个答案:

答案 0 :(得分:4)

可以使用带有标记的AVX2来完成。

为了使此工作正常进行,我建议使用vector<uint16_t>作为计数。增加计数是最大的问题,而我们需要扩展的越多,问题就越大。 uint16_t足以计数一页,因此您一次可以计数一页,并将计数器添加到总计的一组较宽计数器中。这是一些开销,但是要比在主循环中扩大宽度要少得多。

计数的big-endian排序非常烦人,引入了更多洗牌才能正确处理。因此,我建议将其错误并稍后重新排序(也许是在将它们求和成总数时?)。可以免费维护“先右移7位,然后移6位,然后移5位”的顺序,因为我们可以根据需要选择64位块的移位计数。因此,在下面的代码中,计数的实际顺序为:

  • 最低有效字节的第7位
  • 第二个字节的第7位
  • ...
  • 最高有效字节的第7位
  • 最低有效字节的第6位
  • ...

所以每8个一组反转一次。 (至少这是我打算做的,AVX2 unpacks令人困惑)

代码(未经测试):

while( counter > 0 )
{
    __m256i data = _mm256_set1_epi64x(*ptr_data);        
    __m256i data1 = _mm256_srlv_epi64(data, _mm256_set_epi64x(4, 6, 5, 7));
    __m256i data2 = _mm256_srlv_epi64(data, _mm256_set_epi64x(0, 2, 1, 3));
    data1 = _mm256_and_si256(data1, _mm256_set1_epi8(1));
    data2 = _mm256_and_si256(data2, _mm256_set1_epi8(1));

    __m256i zero = _mm256_setzero_si256();

    __m256i c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[0]);
    c = _mm256_add_epi16(_mm256_unpacklo_epi8(data1, zero), c);
    _mm256_storeu_si256((__m256i*)&ptr_dot_counts[0], c);

    c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[16]);
    c = _mm256_add_epi16(_mm256_unpackhi_epi8(data1, zero), c);
    _mm256_storeu_si256((__m256i*)&ptr_dot_counts[16], c);

    c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[32]);
    c = _mm256_add_epi16(_mm256_unpacklo_epi8(data2, zero), c);
    _mm256_storeu_si256((__m256i*)&ptr_dot_counts[32], c);

    c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[48]);
    c = _mm256_add_epi16(_mm256_unpackhi_epi8(data2, zero), c);
    _mm256_storeu_si256((__m256i*)&ptr_dot_counts[48], c);

    ptr_dot_counts += 64;
    ++ptr_data;
    --counter;
    if( ++line_count >= 218 )
    {
        ptr_dot_counts = dot_counts.data( );
        line_count = 0;
    }
}

可以进一步展开,一次处理多个行。这样做是有好处的,因为如前所述,对计数器求和是最大的问题,并且按行展开将少做这些,而对寄存器进行更简单的求和。

使用的Somme内部函数:

  • _mm256_set1_epi64x,将一个int64_t复制到向量的所有4个64位元素中。 uint64_t也可以。
  • _mm256_set_epi64x,将4个64位值转换为向量。
  • _mm256_srlv_epi64,右移逻辑,计数可变(每个元素的计数可以不同)。
  • _mm256_and_si256,按位与。
  • 另外,
  • _mm256_add_epi16适用于16位元素。
  • _mm256_unpacklo_epi8_mm256_unpackhi_epi8,可能由该页面上的图表最好地解释了

可以使用一个uint64_t来保存64个独立和的所有第0位,使用另一个uint64_t来保存和的所有第1位,以“垂直”求和。可以通过按位算法模拟完整加法器(电路组件)来完成。然后,不只是将0或1加到计数器上,而是一次全部添加了更大的数字。

垂直和也可以向量化,但这会显着膨胀将垂直和添加到列和的代码,所以我在这里没有这样做。它应该会有所帮助,但这只是很多代码。

示例(未经测试):

size_t y;
// sum 7 rows at once
for (y = 0; (y + 6) < 15125; y += 7) {
    ptr_dot_counts = dot_counts.data( );
    ptr_data = ripped_pdf_data.data( ) + y * 218;
    for (size_t x = 0; x < 218; x++) {
        uint64_t dataA = ptr_data[0];
        uint64_t dataB = ptr_data[218];
        uint64_t dataC = ptr_data[218 * 2];
        uint64_t dataD = ptr_data[218 * 3];
        uint64_t dataE = ptr_data[218 * 4];
        uint64_t dataF = ptr_data[218 * 5];
        uint64_t dataG = ptr_data[218 * 6];
        // vertical sums, 7 bits to 3
        uint64_t abc0 = (dataA ^ dataB) ^ dataC;
        uint64_t abc1 = (dataA ^ dataB) & dataC | (dataA & dataB);
        uint64_t def0 = (dataD ^ dataE) ^ dataF;
        uint64_t def1 = (dataD ^ dataE) & dataF | (dataD & dataE);
        uint64_t bit0 = (abc0 ^ def0) ^ dataG;
        uint64_t c1   = (abc0 ^ def0) & dataG | (abc0 & def0);
        uint64_t bit1 = (abc1 ^ def1) ^ c1;
        uint64_t bit2 = (abc1 ^ def1) & c1 | (abc1 & def1);
        // add vertical sums to column counts
        __m256i bit0v = _mm256_set1_epi64x(bit0);
        __m256i data01 = _mm256_srlv_epi64(bit0v, _mm256_set_epi64x(4, 6, 5, 7));
        __m256i data02 = _mm256_srlv_epi64(bit0v, _mm256_set_epi64x(0, 2, 1, 3));
        data01 = _mm256_and_si256(data01, _mm256_set1_epi8(1));
        data02 = _mm256_and_si256(data02, _mm256_set1_epi8(1));
        __m256i bit1v = _mm256_set1_epi64x(bit1);
        __m256i data11 = _mm256_srlv_epi64(bit1v, _mm256_set_epi64x(4, 6, 5, 7));
        __m256i data12 = _mm256_srlv_epi64(bit1v, _mm256_set_epi64x(0, 2, 1, 3));
        data11 = _mm256_and_si256(data11, _mm256_set1_epi8(1));
        data12 = _mm256_and_si256(data12, _mm256_set1_epi8(1));
        data11 = _mm256_add_epi8(data11, data11);
        data12 = _mm256_add_epi8(data12, data12);
        __m256i bit2v = _mm256_set1_epi64x(bit2);
        __m256i data21 = _mm256_srlv_epi64(bit2v, _mm256_set_epi64x(4, 6, 5, 7));
        __m256i data22 = _mm256_srlv_epi64(bit2v, _mm256_set_epi64x(0, 2, 1, 3));
        data21 = _mm256_and_si256(data21, _mm256_set1_epi8(1));
        data22 = _mm256_and_si256(data22, _mm256_set1_epi8(1));
        data21 = _mm256_slli_epi16(data21, 2);
        data22 = _mm256_slli_epi16(data22, 2);
        __m256i data1 = _mm256_add_epi8(_mm256_add_epi8(data01, data11), data21);
        __m256i data2 = _mm256_add_epi8(_mm256_add_epi8(data02, data12), data22);

        __m256i zero = _mm256_setzero_si256();

        __m256i c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[0]);
        c = _mm256_add_epi16(_mm256_unpacklo_epi8(data1, zero), c);
        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[0], c);

        c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[16]);
        c = _mm256_add_epi16(_mm256_unpackhi_epi8(data1, zero), c);
        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[16], c);

        c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[32]);
        c = _mm256_add_epi16(_mm256_unpacklo_epi8(data2, zero), c);
        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[32], c);

        c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[48]);
        c = _mm256_add_epi16(_mm256_unpackhi_epi8(data2, zero), c);
        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[48], c);


        ptr_dot_counts += 64;
        ++ptr_data;
    }
}
// leftover rows
for (; y < 15125; y++) {
    ptr_dot_counts = dot_counts.data( );
    ptr_data = ripped_pdf_data.data( ) + y * 218;
    for (size_t x = 0; x < 218; x++) {
        __m256i data = _mm256_set1_epi64x(*ptr_data);
        __m256i data1 = _mm256_srlv_epi64(data, _mm256_set_epi64x(4, 6, 5, 7));
        __m256i data2 = _mm256_srlv_epi64(data, _mm256_set_epi64x(0, 2, 1, 3));
        data1 = _mm256_and_si256(data1, _mm256_set1_epi8(1));
        data2 = _mm256_and_si256(data2, _mm256_set1_epi8(1));

        __m256i zero = _mm256_setzero_si256();

        __m256i c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[0]);
        c = _mm256_add_epi16(_mm256_unpacklo_epi8(data1, zero), c);
        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[0], c);

        c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[16]);
        c = _mm256_add_epi16(_mm256_unpackhi_epi8(data1, zero), c);
        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[16], c);

        c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[32]);
        c = _mm256_add_epi16(_mm256_unpacklo_epi8(data2, zero), c);
        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[32], c);

        c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[48]);
        c = _mm256_add_epi16(_mm256_unpackhi_epi8(data2, zero), c);
        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[48], c);


        ptr_dot_counts += 64;
        ++ptr_data;
    }
}

到目前为止,第二好的方法是一种更简单的方法,除了第一次利用快速的8位求和来运行yloopLen行之外,更像第一个版本:

size_t yloopLen = 32;
size_t yblock = yloopLen * 1;
size_t yy;
for (yy = 0; yy < 15125; yy += yblock) {
    for (size_t x = 0; x < 218; x++) {
        ptr_data = ripped_pdf_data.data() + x;
        ptr_dot_counts = dot_counts.data() + x * 64;
        __m256i zero = _mm256_setzero_si256();

        __m256i c1 = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[0]);
        __m256i c2 = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[16]);
        __m256i c3 = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[32]);
        __m256i c4 = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[48]);

        size_t end = std::min(yy + yblock, size_t(15125));
        size_t y;
        for (y = yy; y < end; y += yloopLen) {
            size_t len = std::min(size_t(yloopLen), end - y);
            __m256i count1 = zero;
            __m256i count2 = zero;

            for (size_t t = 0; t < len; t++) {
                __m256i data = _mm256_set1_epi64x(ptr_data[(y + t) * 218]);
                __m256i data1 = _mm256_srlv_epi64(data, _mm256_set_epi64x(4, 6, 5, 7));
                __m256i data2 = _mm256_srlv_epi64(data, _mm256_set_epi64x(0, 2, 1, 3));
                data1 = _mm256_and_si256(data1, _mm256_set1_epi8(1));
                data2 = _mm256_and_si256(data2, _mm256_set1_epi8(1));
                count1 = _mm256_add_epi8(count1, data1);
                count2 = _mm256_add_epi8(count2, data2);
            }

            c1 = _mm256_add_epi16(_mm256_unpacklo_epi8(count1, zero), c1);
            c2 = _mm256_add_epi16(_mm256_unpackhi_epi8(count1, zero), c2);
            c3 = _mm256_add_epi16(_mm256_unpacklo_epi8(count2, zero), c3);
            c4 = _mm256_add_epi16(_mm256_unpackhi_epi8(count2, zero), c4);
        }

        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[0], c1);
        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[16], c2);
        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[32], c3);
        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[48], c4);
    }
}

以前存在一些测量问题,最终这实际上并没有比上面的“垂直和”版本更好,但也没有差很多。