最近,我发现AVX2没有__m256i的popcount,而我发现做类似操作的唯一方法是遵循Wojciech Mula算法:
__m256i count(__m256i v) {
__m256i lookup = _mm256_setr_epi8(0, 1, 1, 2, 1, 2, 2, 3, 1, 2,
2, 3, 2, 3, 3, 4, 0, 1, 1, 2, 1, 2, 2, 3,
1, 2, 2, 3, 2, 3, 3, 4);
__m256i low_mask = _mm256_set1_epi8(0x0f);
__m256i lo =_mm256_and_si256(v,low_mask);
__m256i hi = _mm256_and_si256( _mm256_srli_epi32(v, 4), low_mask);
__m256i popcnt1 = _mm256_shuffle_epi8(lookup,lo);
__m256i popcnt2 = _mm256_shuffle_epi8(lookup,hi);
__m256i total = _mm256_add_epi8(popcnt1,popcnt2);
return _mm256_sad_epu8(total,_mm256_setzero_si256());
}
问题在于它返回的是8 short到long的总和,而不是4 short到int的总和。
正在发生的事情:
我有__m256i x,其中包含这8个32位int:
__ m256i res = count(x);
res包含:
结果是4个长64位
期望:
我有__m256i x,其中包含那8 32位整数:
__ m256i res = count(x);
res包含:
结果是8 int 32位。
希望我很清楚,不要犹豫,请我提供更多精度。
谢谢。
答案 0 :(得分:2)
您引用的原始代码依赖于_mm256_sad_epu8
内部函数,它专门用于求和64位字中的字节。
要获得相同的结果(包括32位字的总和),您需要做一些稍有不同的事情。以下应该起作用:
__m256i popcount_pshufb32(__m256i v) {
__m256i lookup = = _mm256_setr_epi8(0, 1, 1, 2, 1, 2, 2, 3, 1, 2,
2, 3, 2, 3, 3, 4, 0, 1, 1, 2, 1, 2, 2, 3,
1, 2, 2, 3, 2, 3, 3, 4);
__m256i low_mask = _mm256_set1_epi8(0x0f);
__m256i lo = _mm256_and_si256(v, low_mask);
__m256i hi = _mm256_and_si256(_mm256_srli_epi16(v, 4), low_mask);
__m256i popcnt1 = _mm256_shuffle_epi8(lookup, lo);
__m256i popcnt2 = _mm256_shuffle_epi8(lookup, hi);
__m256i sum8 = _mm256_add_epi8(popcnt1, popcnt2);
return _mm256_srli_epi32(
_mm256_mullo_epi32(sum8, _mm256_set1_epi32(0x01010101)), 24);
}
所以我们用乘法和移位代替_mm256_sad_epu8
。那应该是合理的。在我的测试中,it is slightly slower than the original 64-bit version, but the difference is relatively small。
您可以使用更多的内在函数来获得更好的性能:
__m256i popcount_pshufb32(__m256i v) {
__m256i lookup = = _mm256_setr_epi8(0, 1, 1, 2, 1, 2, 2, 3, 1, 2,
2, 3, 2, 3, 3, 4, 0, 1, 1, 2, 1, 2, 2, 3,
1, 2, 2, 3, 2, 3, 3, 4);
__m256i low_mask = _mm256_set1_epi8(0x0f);
__m256i lo = _mm256_and_si256(v, low_mask);
__m256i hi = _mm256_and_si256(_mm256_srli_epi16(v, 4), low_mask);
__m256i popcnt1 = _mm256_shuffle_epi8(lookup, lo);
__m256i popcnt2 = _mm256_shuffle_epi8(lookup, hi);
__m256i sum8 = _mm256_add_epi8(popcnt1, popcnt2);
return _mm256_madd_epi16(_mm256_maddubs_epi16(sum8, _mm256_set1_epi8(1)),
_mm256_set1_epi16(1));
}