位操作:将公共部分保持在最后一个不同位的左侧

时间:2014-02-03 05:54:07

标签: c++ c algorithm optimization bit-manipulation

考虑用二进制编写的两个数字(左边的MSB):

X = x7 x6 x5 x4 x3 x2 x1 x0

Y = y7 y6 y5 y4 y3 y2 y1 y0

这些数字可以具有任意数量的位,但两者的类型相同。现在考虑x7 == y7x6 == y6x5 == y5,但x4 != y4

如何计算:

Z = x7 x6 x5 0 0 0 0 0

或换句话说,如何有效地计算一个数字,使公共部分保持在最后一个不同位的左边?

template <typename T>
inline T f(const T x, const T y) 
{
    // Something here
}

例如,对于:

x = 10100101
y = 10110010

它应该返回

z = 10100000

注意:它用于超级计算,此操作将执行数十亿次,因此应该避免逐个扫描这些位...

5 个答案:

答案 0 :(得分:6)

我的答案基于@JerryCoffin的答案。

int d = x ^ y;
d = d | (d >> 1);
d = d | (d >> 2);
d = d | (d >> 4);
d = d | (d >> 8);
d = d | (d >> 16);
int z = x & (~d);

答案 1 :(得分:3)

这个问题的一部分在位操作中半定期出现:“带OR的并行后缀”或“前缀”(也就是说,根据你听的人,低位被称为后缀或前缀)。显然,一旦你有办法做到这一点,将它扩展到你想要的东西是微不足道的(如其他答案所示)。

无论如何,显而易见的方法是:

x |= x >> 1
x |= x >> 2
x |= x >> 4
x |= x >> 8
x |= x >> 16

但你可能并不局限于简单的操作符。

对于Haswell,我发现的最快方式是:

lzcnt rax, rax     ; number of leading zeroes, sets carry if rax=0
mov edx, 64
sub edx, eax
mov rax, -1
bzhi rax, rax, rdx ; reset the bits in rax starting at position rdx

其他竞争者是:

mov rdx, -1
bsr rax, rax       ; position of the highest set bit, set Z flag if no bit
cmovz rdx, rax     ; set rdx=rax iff Z flag is set
xor eax, 63
shrx rax, rdx, rax ; rax = rdx >> rax

lzcnt rax, rax
sbb rdx, rdx       ; rdx -= rdx + carry (so 0 if no carry, -1 if carry)
not rdx
shrx rax, rdx, rax

但他们没那么快。

我也考虑了

lzcnt rax, rax
mov rax, [table+rax*8]

但很难公平地比较它,因为它是唯一一个花费缓存空间的东西,它具有非局部效果。

对各种方法进行基准测试导致this question关于lzcnt的一些奇怪行为。

它们都依赖于一些快速的方法来确定最高设置位的位置,如果你真的需要,你可以使用转换为浮动和指数提取,所以可能大多数平台都可以使用类似的东西。

如果移位计数等于或大于操作数大小,则给出零的移位将非常好地解决此问题。 x86没有,但也许你的平台没有。

如果你有一个快速位反转指令,你可以做类似的事情:(这不是ARM asm)

rbit r0, r0
neg r1, r0
or r0, r1, r0
rbit r0, r0

答案 2 :(得分:2)

比较几种算法会导致这种排名:

在下面的测试中有一个1或10的内循环:

  1. 利用内置的位扫描功能。
  2. 用或移位最低有效位(函数) @Egor Skriptunoff)。
  3. 涉及查找表。
  4. 扫描最重要的位(第二个 @Tomas的功能。
  5. InnerLoops = 10:

    Timing 1: 0.101284
    Timing 2: 0.108845
    Timing 3: 0.102526
    Timing 4: 0.191911
    

    100或更大的内循环:

    1. 利用内置的位扫描功能。
    2. 涉及查找表。
    3. 用或移位最低有效位(函数) @Egor Skriptunoff)。
    4. 扫描最重要的位(第二个 @Tomas的功能。
    5. InnerLoops = 100:

      Timing 1: 0.441786
      Timing 2: 0.507651
      Timing 3: 0.548328
      Timing 4: 0.593668
      

      测试:

      #include <algorithm>
      #include <chrono>
      #include <limits>
      #include <iostream>
      #include <iomanip>
      
      // Functions
      // =========
      
      inline unsigned function1(unsigned  a, unsigned b)
      {
          a ^= b;
          if(a) {
              int n = __builtin_clz (a);
              a = (~0u) >> n;
          }
          return ~a & b;
      }
      
      typedef std::uint8_t byte;
      static byte msb_table[256] = {
          0, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
          6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
          7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
          7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
          8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
          8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
          8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
          8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
      };
      
      inline unsigned function2(unsigned a, unsigned  b)
      {
          a ^= b;
          if(a) {
              unsigned n = 0;
              if(a >> 24) n = msb_table[byte(a >> 24)] + 24;
              else if(a >> 16) n = msb_table[byte(a >> 16)] + 16;
              else if(a >> 8) n = msb_table[byte(a >> 8)] + 8;
              else n = msb_table[byte(a)];
              a = (~0u) >> (32-n);
          }
          return ~a & b;
      }
      
      inline unsigned function3(unsigned  a, unsigned  b)
      {
          unsigned d = a ^ b;
          d = d | (d >> 1);
          d = d | (d >> 2);
          d = d | (d >> 4);
          d = d | (d >> 8);
          d = d | (d >> 16);
          return a & (~d);;
      }
      
      inline unsigned function4(unsigned  a, unsigned  b)
      {
          const unsigned maxbit = 1u << (std::numeric_limits<unsigned>::digits - 1);
          unsigned msb = maxbit;
          a ^= b;
          while( ! (a & msb))
              msb >>= 1;
          if(msb == maxbit) return 0;
          else {
              msb <<= 1;
              msb  -= 1;
              return ~msb & b;
          }
      }
      
      
      // Test
      // ====
      
      inline double duration(
          std::chrono::system_clock::time_point start,
          std::chrono::system_clock::time_point end)
      {
          return double((end - start).count())
              / std::chrono::system_clock::period::den;
      }
      
      int main() {
          typedef unsigned (*Function)(unsigned , unsigned);
          Function fn[] = {
              function1,
              function2,
              function3,
              function4,
          };
          const unsigned N = sizeof(fn) / sizeof(fn[0]);
          std::chrono::system_clock::duration timing[N] = {};
          const unsigned OuterLoops = 1000000;
          const unsigned InnerLoops = 100;
          const unsigned Samples = OuterLoops * InnerLoops;
          unsigned* A = new unsigned[Samples];
          unsigned* B = new unsigned[Samples];
          for(unsigned i = 0; i < Samples; ++i) {
              A[i] = std::rand();
              B[i] = std::rand();
          }
          unsigned F[N];
          for(unsigned f = 0; f < N; ++f) F[f] = f;
          unsigned result[N];
          for(unsigned i = 0; i < OuterLoops; ++i) {
              std::random_shuffle(F, F + N);
              for(unsigned f = 0; f < N; ++f) {
                  unsigned g = F[f];
                  auto start = std::chrono::system_clock::now();
                  for(unsigned j = 0; j < InnerLoops; ++j) {
                      unsigned index = i + j;
                      unsigned a = A[index];
                      unsigned b = B[index];
                      result[g] = fn[g](a, b);
                  }
                  auto end = std::chrono::system_clock::now();
                  timing[g] += (end-start);
              }
              for(unsigned f = 1; f < N; ++f) {
                  if(result[0] != result[f]) {
                      std::cerr << "Different Results\n" << std::hex;
                      for(unsigned g = 0; g < N; ++g)
                          std::cout << "Result " << g+1 << ": " << result[g] << '\n';
                      exit(-1);
                  }
              }
          }
      
          for(unsigned i = 0; i < N; ++i) {
              std::cout
                  << "Timing " << i+1 << ": "
                  << double(timing[i].count()) / std::chrono::system_clock::period::den
                  << "\n";
          }
      }
      

      <强>编译器:

      g ++ 4.7.2

      <强>设备:

      英特尔®酷睿™i3-2310M CPU @ 2.10GHz×4 7.7 GiB

答案 3 :(得分:1)

您可以将其简化为更容易找到最高设置位(最高1)的问题,这实际上与查找ceil(log 2 X)相同。

unsigned int x, y, c, m;
int b;

c = x ^ y;          // xor : 00010111

// now it comes: b = number of highest set bit in c
// perhaps some special operation or instruction exists for that
b = -1;
while (c) {
    b++;
    c = c >> 1;
}                  // b == 4

m = (1 << (b + 1)) - 1;   // creates a mask: 00011111
return x & ~m;    // x AND NOT M
return y & ~m;    // should return the same result

事实上,如果你可以轻松地计算ceil(log 2 c),那么只需减去1即可得到m,而不需要使用b进行计算上面的循环。

如果您没有这样的功能,那么仅使用基本汇编级操作(位移一位:<<=1>>=1)的简单优化代码将如下所示:

c = x ^ y;        // c == 00010111 (xor)
m = 1;
while (c) {
    m <<= 1; 
    c >>= 1;
}                 // m == 00100000
m--;              // m == 00011111 (mask)
return x & ~m;    // x AND NOT M

这可以编译成非常快的代码,大多数情况下每行一两个机器指令。

答案 4 :(得分:1)

这有点难看,但假设8位输入,你可以这样做:

int x = 0xA5; // 1010 0101
int y = 0xB2; // 1011 0010
unsigned d = x ^ y;

int mask = ~(d | (d >> 1) | (d >> 2) | (d >> 3) | (d >> 4) | (d >> 5) | (d >> 6));

int z = x & mask;

我们首先计算数字的异或,在它们相等时给出0,在它们不同的情况下给出1。对于你的例子,这给出了:

00010111

然后,我们将这个权利和包容性 - 或者它自己转移到7个可能的位位置:

00010111
00001011
00000101
00000010
00000001

这给出了:

00011111

原始数字相等的是0,而不同的是1。然后我们将其反转为:

11100000

然后我们and使用其中一个原始输入(无关紧要):

10100000

...正是我们想要的结果(与简单的x & y不同,它也适用于xy的其他值。

当然,这个可以扩展到任意宽度,但如果你正在处理(比方说)64位数字,那么d | (d>>1) | ... | (d>>63);会有点长笨拙的一面。