如何编写无分支std :: vector扫描?

时间:2016-08-05 23:01:30

标签: c++ arrays vector conditional-statements

我想在数组上编写一个简单的扫描。我有一个std::vector<int> data,我想找到元素小于9的所有数组索引,并将它们添加到结果向量中。我可以用分支写这个:

for (int i = 0; i < data.size(); ++i)
    if (data[i] < 9)
        r.push_back(i);

这给出了正确答案,但我想将其与无分支版本进行比较。

使用原始数组 - 假设data是一个int数组,length是其中的元素数,r是一个有足够空间的结果数组 - 我可以写点如下:

int current_write_point = 0;
for (int i = 0; i < length; ++i){
    r[current_write_point] = i;
    current_write_point += (data[i] < 9);
}

如何使用data的矢量获得类似的行为?

3 个答案:

答案 0 :(得分:6)

让我们看看实际的compiler output

auto scan_branch(const std::vector<int>& v)
{
  std::vector<int> res;
  int insert_index = 0;
  for(int i = 0; i < v.size(); ++i)
  {
    if (v[i] < 9)
    {
       res.push_back(i);
    } 
  }
  return res;
}

此代码显然在disassembly的第26行有一个分支。如果它大于或等于9,它只会继续下一个元素,但是如果小于9,则会为push_back执行一些可怕的代码,我们继续。没什么意外的。

auto scan_nobranch(const std::vector<int>& v)
{
  std::vector<int> res;
  res.resize(v.size());

  int insert_index = 0;
  for(int i = 0; i < v.size(); ++i)
  {
    res[insert_index] = i;
    insert_index += v[i] < 9;
  }

  res.resize(insert_index);
  return res;
}

然而,这个只有一个条件移动,你可以在disassembly的第190行看到。看起来我们有一个胜利者。由于条件移动不会导致流水线停滞,因此在此流程中没有分支(条件检查除外)。

答案 1 :(得分:-1)

std::copy_if(std::begin(data), std::end(data), std::back_inserter(r));

答案 2 :(得分:-2)

好吧,您可以事先调整矢量大小并保留算法:

// Resize the vector so you can index it normally
r.resize(length);

// Do your algorithm like before
int current_write_point = 0;
for (int i = 0; i < length; ++i){
    r[current_write_point] = i;
    current_write_point += (data[i] < 9);
}

// Afterwards, current_write_point can be used to shrink the vector, so
// there are no excess elements not written to
r.resize(current_write_point + 1);

如果你不想要比较,你可以使用一些按位和带有短路的布尔操作来确定。

首先,我们知道所有负整数都小于9.其次,如果它是正数,我们可以使用位掩码来确定整数是否在0-15范围内(实际上,我们&#39; ll检查它是否不在该范围内,因此大于15)。然后,我们知道如果从该数字中减去8的结果是负数,则结果小于9: 实际上,我只是想出了一个更好的方法。由于我们可以轻松确定是否x < 0,因此我们可以将x减去9以确定是否x < 9

#include <iostream>

// Use bitwise operations to determine if x is negative
int n(int x) {
    return x & (1 << 31);
}

int main() {
    int current_write_point = 0;
    for (int i = 0; i < length; ++i){
        r[current_write_point] = i;
        current_write_point += n(data[i] - 9);
    }
}