使用AVX2查找非零字节的索引

时间:2017-08-14 11:08:44

标签: simd avx2

已有一个使用128位向量查找非零字节索引的解决方案：https://stackoverflow.com/a/41959079/3648510

另一个解决方案（arr2ind_pext）演示了如何在256位向量中查找非零字节，但它返回的索引是4字节整数：https://stackoverflow.com/a/41958528/3648510

我的初衷是修改arr2ind_pext解决方案以返回8位而不是32位索引。

不过现在我觉得32位索引也可以接受；我真正想要的是一个能尽可能快地找出256位向量（或分两次迭代处理、最多512位）中非零字节索引的解决方案。

我目前基于arr2ind_pext的解决方案在这里:

unsigned nonzeros(const __m256i _in, __m256i& _ivec, unsigned* _out)
{
    uint64_t cntr_const = 0xFEDCBA9876543210;
    __m256i  shft       = _mm256_set_epi64x(0x04,0x00,0x04,0x00);
    __m256i  vmsk       = _mm256_set1_epi8(0x0F);
    __m256i  shf_lo     = _mm256_set_epi8(
        0x80, 0x80, 0x80, 0x0B,  0x80, 0x80, 0x80, 0x03,  0x80, 0x80, 0x80, 0x0A,  0x80, 0x80, 0x80, 0x02,
        0x80, 0x80, 0x80, 0x09,  0x80, 0x80, 0x80, 0x01,  0x80, 0x80, 0x80, 0x08,  0x80, 0x80, 0x80, 0x00
    );
    __m256i  shf_hi     = _mm256_set_epi8(
        0x80, 0x80, 0x80, 0x0F,  0x80, 0x80, 0x80, 0x07,  0x80, 0x80, 0x80, 0x0E,  0x80, 0x80, 0x80, 0x06,
        0x80, 0x80, 0x80, 0x0D,  0x80, 0x80, 0x80, 0x05,  0x80, 0x80, 0x80, 0x0C,  0x80, 0x80, 0x80, 0x04
    );
    __m256i  pshufbcnst = _mm256_set_epi8(
        0x80, 0x80, 0x80, 0x80,  0x80, 0x80, 0x80, 0x80,  0x1E, 0x1C, 0x1A, 0x18,  0x16, 0x14, 0x12, 0x10,
        0x80, 0x80, 0x80, 0x80,  0x80, 0x80, 0x80, 0x80,  0x0E, 0x0C, 0x0A, 0x08,  0x06, 0x04, 0x02, 0x00
    );

    __m256i  msk        = _mm256_cmpeq_epi8(_in, _mm256_setzero_si256()); // Generate 32 bit mask
             msk        = _mm256_srli_epi64(msk, 4);                      // Pack 32x8 bit mask to 32x4 bit mask
             msk        = _mm256_shuffle_epi8(msk, pshufbcnst);           // Pack 32x8 bit mask to 32x4 bit mask
             msk        = _mm256_xor_si256(msk, _mm256_set1_epi8(-1));    // Invert 32x4 mask

    uint64_t m64_0 = _mm256_extract_epi64(msk, 0);
    uint64_t m64_1 = _mm256_extract_epi64(msk, 2);
    unsigned m64_count_0 = _mm_popcnt_u64(m64_0) >> 2;             // p is the number of nonzeros in 16 bytes of a
    unsigned m64_count_1 = _mm_popcnt_u64(m64_1) >> 2;             // p is the number of nonzeros in 16 bytes of a
    unsigned* out_0 = &_out[0];
    unsigned* out_1 = &_out[m64_count_0];

    auto f = [&](uint64_t msk64, unsigned* __restrict__  _out)
    {
        uint64_t cntr       = _pext_u64(cntr_const, msk64);           // parallel bits extract. cntr contains p 4-bit integers. The 16 4-bit integers in cntr_const are shuffled to the p 4-bit integers that we want

        // Unpack p 4-bit integers to p 32-bit integers
        __m256i  cntr256    = _mm256_set1_epi64x(cntr);
                 cntr256    = _mm256_srlv_epi64(cntr256, shft);
                 cntr256    = _mm256_and_si256(cntr256, vmsk);
        __m256i  cntr256_lo = _mm256_shuffle_epi8(cntr256, shf_lo);
        __m256i  cntr256_hi = _mm256_shuffle_epi8(cntr256, shf_hi);
                 cntr256_lo = _mm256_add_epi8(_ivec, cntr256_lo);
                 cntr256_hi = _mm256_add_epi8(_ivec, cntr256_hi);

        _mm256_storeu_si256((__m256i *)&_out[0], cntr256_lo);         // Note that the stores of iteration i and i+16 may overlap
        _mm256_storeu_si256((__m256i *)&_out[8], cntr256_hi);     // Array ind has to be large enough to avoid segfaults. At most 16 integers are written more than strictly necessary
    };

    f(m64_0, out_0);

    _ivec = _mm256_add_epi32(_ivec, _mm256_set1_epi32(16));
    f(m64_1, out_1);

    return m64_count_0 + m64_count_1;
}

0 个答案:

没有答案