从连续的单词序列中提取任意位数的最有效方法是哪种?

时间:2014-12-23 09:53:25

标签: c++ algorithm bit-manipulation simd intrinsics

假设我们有一个std::vector或任何其他序列容器(有时它将是一个双端队列),它存储uint64_t个元素。

现在,让我们将此向量视为size() * 64个连续位的序列。我需要找到给定[begin, end)范围内的位形成的单词,给定end - begin <= 64所以它适合于单词。

我现在的解决方案找到了两个单词,其中的部分将形成结果,并分别掩盖和组合它们。由于我需要尽可能高效,因此我尝试在没有任何if分支的情况下对所有内容进行编码,以免导致分支错误预测,例如,当整个范围适合于单词或当它跨越两个单词时,不采用不同的路径。要做到这一点,我需要对shiftlshiftr函数进行编码,这些函数除了将>><<运算符转换为指定数量之外什么都不做,但是优雅处理n大于64时的情况,否则将是未定义的行为。

另一点是,现在编码的get()函数也适用于空范围,在数学意义上,例如不仅如果开始==结束,而且如果开始&gt;结束,这是调用此函数的主算法所必需的。再一次,我试图在没有简单分支的情况下做到这一点,并且在这种情况下返回零。

然而,同样在查看汇编代码时,所有这些看起来都太复杂了,无法执行这样一个看似简单的任务。此代码在性能关键算法中运行,该算法运行有点太慢。 valgrind告诉我们这个函数被称为2.3亿次,占总执行时间的40%,所以我真的需要让它更快。

那么你能帮我找到一种更简单和/或更有效的方法来完成这项任务吗? 我也不太关心可移植性 。使用x86 SIMD内在函数(SSE3 / 4 / AVX ecc ...)或编译器内置函数的解决方案是可以的,只要我可以使用g++clang编译它们。

我目前的代码包含在下面:

using word_type = uint64_t;
const size_t W = 64;

// Shift right, but without being undefined behaviour if n >= 64
word_type shiftr(word_type val, size_t n)
{
    uint64_t good = n < W;

    return good * (val >> (n * good));
}

// Shift left, but without being undefined behaviour if n >= 64
word_type shiftl(word_type val, size_t n)
{
    uint64_t good = n < W;

    return good * (val << (n * good));
}

// Mask the word preserving only the lower n bits.
word_type lowbits(word_type val, size_t n)
{
    word_type mask = shiftr(word_type(-1), W - n);

    return val & mask;
}

// Struct for return values of locate()
struct range_location_t {
    size_t lindex; // The word where is located the 'begin' position
    size_t hindex; // The word where is located the 'end' position
    size_t lbegin; // The position of 'begin' into its word
    size_t llen;   // The length of the lower part of the word
    size_t hlen;   // The length of the higher part of the word
};

// Locate the one or two words that will make up the result
range_location_t locate(size_t begin, size_t end)
{
    size_t lindex = begin / W;
    size_t hindex = end / W;
    size_t lbegin = begin % W;
    size_t hend   = end % W;

    size_t len = (end - begin) * size_t(begin <= end);
    size_t hlen = hend * (hindex > lindex);
    size_t llen = len - hlen;

    return { lindex, hindex, lbegin, llen, hlen };
}

// Main function.
template<typename Container>
word_type get(Container const&container, size_t begin, size_t end)
{
    assert(begin < container.size() * W);
    assert(end <= container.size() * W);

    range_location_t loc = locate(begin, end);

    word_type low = lowbits(container[loc.lindex] >> loc.lbegin, loc.llen);

    word_type high = shiftl(lowbits(container[loc.hindex], loc.hlen), loc.llen);

    return high | low;
}

非常感谢。

2 个答案:

答案 0 :(得分:4)

这取代了get()和get()使用的所有辅助函数。它包含一个条件分支并保存大约16个算术运算,这意味着它通常应该运行得更快。通过一些优化编译,它也产生非常短的代码。最后,它解决了在case end == container.size()* W中访问容器[container.size()]的错误。

最棘手的部分是&#34; hi-(hi&gt; 0)&#34;除非hi为0,否则从hi中减去1.减去1不会改变任何东西,除非hi指向单词边界,即hi%64 == 0。在这种情况下,我们需要来自上部容器条目的0位,因此仅使用较低的容器条目。通过在计算hi_off之前减去1,我们确保条件&#34; hi_off == lo_off&#34;,并且我们遇到更简单的情况。

在那个更简单的情况下,我们只需要一个容器入口,并在两侧切掉一些位。 hi_val就是那个条目,高位已经被删除了,所以剩下要做的就是删除一些低位。

在不太简单的情况下,我们也必须读取下面的容器条目,去除它的未使用的字节,并组合两个条目。

namespace {
  size_t   const upper_mask = ~(size_t)0u << 6u;
  unsigned const lower_mask = (unsigned)~upper_mask;
}

word_type get ( Container const &container, size_t lo, size_t hi )
{
  size_t lo_off = lo       >>6u;  assert ( lo_off < container.size() );
  size_t hi_off = hi-(hi>0)>>6u;  assert ( hi_off < container.size() );
  unsigned hi_shift = lower_mask&(unsigned)(upper_mask-hi);
  word_type hi_val = container[hi_off] << hi_shift >> hi_shift;
  unsigned lo_shift = lower_mask&(unsigned)lo;
  if ( hi_off == lo_off ) return hi_val >> lo_shift; // use hi_val as lower word
  return ( hi_val<<W-lo_shift | container[lo_off]>>lo_shift ) * (lo_off<hi_off);
}

答案 1 :(得分:3)

正如聊天中所宣布的,我添加了一个精确的答案。它包含三个部分,每个部分后面都有该部分的描述。

第一部分,get.h,是我的解决方案,但是已经推广并进行了一次修正。

第二部分got.h是问题中公布的原始算法,也适用于任何未签名类型的任何STL容器。

第三部分main.cpp包含单元测试,用于验证正确性和测量性能。

#include <cstddef>

using std::size_t;

template < typename C >
typename C::value_type get ( C const &container, size_t lo, size_t hi )
{

   typedef typename C::value_type item; // a container entry
   static unsigned const bits = (unsigned)sizeof(item)*8u; // bits in an item
   static size_t const mask = ~(size_t)0u/bits*bits; // huge multiple of bits

   // everthing above has been computed at compile time. Now do some work:

   size_t lo_adr = (lo       ) / bits; // the index in the container of ...
   size_t hi_adr = (hi-(hi>0)) / bits; // ... the lower or higher item needed

   // we read container[hi_adr] first and possibly delete the highest bits:

   unsigned hi_shift = (unsigned)(mask-hi)%bits;
   item hi_val = container[hi_adr] << hi_shift >> hi_shift;

   // if all bits are in the same item, we delete the lower bits and are done:

   unsigned lo_shift = (unsigned)lo%bits;
   if ( hi_adr <= lo_adr ) return (hi_val>>lo_shift) * (lo<hi);

   // else we have to read the lower item as well, and combine both

   return ( hi_val<<bits-lo_shift | container[lo_adr]>>lo_shift );

}

第一部分,上面的get.h,是我原来的解决方案,但是可以推广使用任何无符号整数类型的STL容器。因此,您也可以使用和测试32位整数或128位整数。我仍然使用unsigned作为非常小的数字,但你也可以用size_t替换它们。该算法几乎没有变化,只有一个小的修正 - 如果lo是容器中的总位数,我之前的get()将访问容器大小正上方的项目。现在已修复。

#include <cstddef>

using std::size_t;

// Shift right, but without being undefined behaviour if n >= 64
template < typename val_type >
val_type shiftr(val_type val, size_t n)
{
   val_type good = n < sizeof(val_type)*8;
   return good * (val >> (n * good));
}

// Shift left, but without being undefined behaviour if n >= 64
template < typename val_type >
val_type shiftl(val_type val, size_t n)
{
   val_type good = n < sizeof(val_type)*8;
   return good * (val << (n * good));
}

// Mask the word preserving only the lower n bits.
template < typename val_type >
val_type lowbits(val_type val, size_t n)
{
    val_type mask = shiftr<val_type>((val_type)(-1), sizeof(val_type)*8 - n);
    return val & mask;
}

// Struct for return values of locate()
struct range_location_t {
   size_t lindex; // The word where is located the 'begin' position
   size_t hindex; // The word where is located the 'end' position
   size_t lbegin; // The position of 'begin' into its word
   size_t llen;   // The length of the lower part of the word
   size_t hlen;   // The length of the higher part of the word
};

// Locate the one or two words that will make up the result
template < typename val_type >
range_location_t locate(size_t begin, size_t end)
{
   size_t lindex = begin / (sizeof(val_type)*8);
   size_t hindex = end / (sizeof(val_type)*8);
   size_t lbegin = begin % (sizeof(val_type)*8);
   size_t hend   = end % (sizeof(val_type)*8);

   size_t len = (end - begin) * size_t(begin <= end);
   size_t hlen = hend * (hindex > lindex);
   size_t llen = len - hlen;

   range_location_t l = { lindex, hindex, lbegin, llen, hlen };
   return l;
}

// Main function.
template < typename C >
typename C::value_type got ( C const&container, size_t begin, size_t end )
{
   typedef typename C::value_type val_type;
   range_location_t loc = locate<val_type>(begin, end);
   val_type low = lowbits<val_type>(container[loc.lindex] >> loc.lbegin, loc.llen);
   val_type high = shiftl<val_type>(lowbits<val_type>(container[loc.hindex], loc.hlen), loc.llen);
   return high | low;
}

第二部分,上面的got.h,是问题中的原始算法,也可以广义接受任何无符号整数类型的STL容器。与get.h一样,除了定义容器类型的单个模板参数外,此版本不使用任何定义,因此可以轻松地测试其他项目大小或容器类型。

#include <vector>
#include <cstddef>
#include <stdint.h>
#include <stdio.h>
#include <sys/time.h>
#include <sys/resource.h>
#include "get.h"
#include "got.h"

template < typename Container > class Test {

   typedef typename Container::value_type val_type;
   typedef val_type (*fun_type) ( Container const &, size_t, size_t );
   typedef void (Test::*fun_test) ( unsigned, unsigned );
   static unsigned const total_bits = 256; // number of bits in the container
   static unsigned const entry_bits = (unsigned)sizeof(val_type)*8u;

   Container _container;
   fun_type _function;
   bool _failed;

   void get_value ( unsigned lo, unsigned hi ) {
      _function(_container,lo,hi); // we call this several times ...
      _function(_container,lo,hi); // ... because we measure ...
      _function(_container,lo,hi); // ... the performance ...
      _function(_container,lo,hi); // ... of _function, ....
      _function(_container,lo,hi); // ... not the performance ...
      _function(_container,lo,hi); // ... of get_value and ...
      _function(_container,lo,hi); // ... of the loop that ...
      _function(_container,lo,hi); // ... calls get_value.
   }

   void verify ( unsigned lo, unsigned hi ) {
      val_type value = _function(_container,lo,hi);
      if ( lo < hi ) {
         for ( unsigned i=lo; i<hi; i++ ) {
            val_type val = _container[i/entry_bits] >> i%entry_bits & 1u;
            if ( val != (value&1u) ) {
               printf("lo=%d hi=%d [%d] is'nt %d\n",lo,hi,i,(unsigned)val);
               _failed = true;
            }
            value >>= 1u;
         }
      }
      if ( value ) {
         printf("lo=%d hi=%d value contains high bits set to 1\n",lo,hi);
         _failed = true;
      }
   }

   void run ( fun_test fun ) {
      for ( unsigned lo=0; lo<total_bits; lo++ ) {
         unsigned h0 = 0;
         if ( lo > entry_bits ) h0 = lo - (entry_bits+1);
         unsigned h1 = lo+64;
         if ( h1 > total_bits ) h1 = total_bits;
         for ( unsigned hi=h0; hi<=h1; hi++ ) {
            (this->*fun)(lo,hi);
         }
      }
   }

   static uint64_t time_used ( ) {
      struct rusage ru;
      getrusage(RUSAGE_THREAD,&ru);
      struct timeval t = ru.ru_utime;
      return (uint64_t) t.tv_sec*1000 + t.tv_usec/1000;
   }

public:

   Test ( fun_type function ): _function(function), _failed() {
      val_type entry;
      unsigned index = 0; // position in the whole bit array
      unsigned value = 0; // last value assigned to a bit
      static char const entropy[] = "The quick brown Fox jumps over the lazy Dog";
      do {
         if ( ! (index%entry_bits) ) entry = 0;
         entry <<= 1;
         entry |= value ^= 1u & entropy[index/7%sizeof(entropy)] >> index%7;
         ++index;
         if ( ! (index%entry_bits) ) _container.push_back(entry);
      } while ( index < total_bits );
   }

   bool correctness() {
      _failed = false;
      run(&Test::verify);
      return !_failed;
   }

   void performance() {
      uint64_t t1 = time_used();
      for ( unsigned i=0; i<999; i++ ) run(&Test::get_value);
      uint64_t t2 = time_used();
      printf("used %d ms\n",(unsigned)(t2-t1));
   }

   void operator() ( char const * name ) {
      printf("testing %s\n",name);
      correctness();
      performance();
   }

};

int main()
{
   typedef typename std::vector<uint64_t> Container;
   Test<Container> test(get<Container>); test("get");
   Test<Container> tost(got<Container>); tost("got");
}

第三部分,main.cpp,包含一类单元测试,并将它们应用于get.h和got.h,即我的解决方案和问题的原始代码,稍加修改。单元测试验证正确性和测量速度。它们通过创建一个256位的容器来验证正确性,用一些数据填充它,读取所有可能的位部分,最多可以容纳到容器条目和许多病理情况中的位,并验证每个结果的正确性。他们通过再次阅读相同的部分并报告用户空间中使用的线程时间来测量速度。