我有一个4个long的数组,我想要计算给定范围内的设置位数。这是我目前使用的函数(其中bitcount(uint64_t)
是一个内联asm函数,它给出了参数中设置位的数量):
unsigned count_bits(const uint64_t *long_ptr, size_t begin, size_t end)
{
uint64_t start_mask = ~((1L << (begin & 63)) - 1);
uint64_t end_mask = ((1L << (end & 63)) - 1);
if (begin >= 0 && begin < 64) {
if (end < 64) {
return bitcount(long_ptr[0] & start_mask & end_mask);
} else if (end < 128) {
return bitcount(long_ptr[0] & start_mask) + bitcount(long_ptr[1] & end_mask);
} else if (end < 192) {
return bitcount(long_ptr[0] & start_mask) + bitcount(long_ptr[1]) + bitcount(long_ptr[2] & end_mask);
} else if (end<256) {
return bitcount(long_ptr[0] & start_mask) + bitcount(long_ptr[1]) + bitcount(long_ptr[2]) + bitcount(long_ptr[3] & end_mask);
} else {
return bitcount(long_ptr[0] & start_mask) + bitcount(long_ptr[1]) + bitcount(long_ptr[2]) + bitcount(long_ptr[3]);
}
} else if (begin >= 64 && begin < 128) {
if (end < 128) {
return bitcount(long_ptr[1] & start_mask & end_mask);
} else if (end < 192) {
return bitcount(long_ptr[1] & start_mask) + bitcount(long_ptr[2] & end_mask);
} else if (end < 256) {
return bitcount(long_ptr[1] & start_mask) + bitcount(long_ptr[2]) + bitcount(long_ptr[3] & end_mask);
} else {
return bitcount(long_ptr[1] & start_mask) + bitcount(long_ptr[2]) + bitcount(long_ptr[3]);
}
} else if (begin >= 128 && begin < 192) {
if (end < 192) {
return bitcount(long_ptr[2] & start_mask & end_mask);
} else if (end < 256) {
return bitcount(long_ptr[2] & start_mask) + bitcount(long_ptr[3] & end_mask);
} else {
return bitcount(long_ptr[2] & start_mask) + bitcount(long_ptr[3]);
}
} else if (begin<256) {
if (end < 256) {
return bitcount(long_ptr[3] & start_mask & end_mask);
} else {
return bitcount(long_ptr[3] & start_mask);
}
} else {
return 0;
}
}
我发现这段代码的性能非常好,但我想知道是否有什么办法可以让它更快,或者重新设计算法可以带来性能提升。
答案 0 :(得分:2)
我创建了2个不同的版本,零分支,我相信大卫Wohlferd评论应该选择适当的紧凑性。我不相信任何分支版本都不会更快。处理器分支预测可以有效地消除对一致数据的跳跃。虽然没有分支的人会一直计算4次比特(除非SSE?)。我在这里发布我的第二个(非常短的)无分支版本。首先是比特计算很复杂。
unsigned bitcount2(const uint64_t *long_ptr, size_t begin, size_t end)
{
uint64_t mask[] = { 0, 0, 0, ~((1ULL << (begin & 63)) - 1), -1LL, -1LL, -1LL, ((1ULL << (end & 63)) - 1), 0, 0, 0 };
uint64_t* b_start = mask+(3 - begin / 64);
uint64_t* b_end = mask + (7 - end / 64);
return bitcount(long_ptr[0] & b_start[0] & b_end[0]) +
bitcount(long_ptr[1] & b_start[1] & b_end[1]) +
bitcount(long_ptr[2] & b_start[2] & b_end[2]) +
bitcount(long_ptr[3] & b_start[3] & b_end[3]);
}