如何使用占位符组合推力比较?

时间:2018-12-10 13:38:29

标签: cuda comparison thrust indices

我有一个大小为N的浮点值A的device_vector。我也有一个浮点值V供比较。根据输入值,我需要提取A的索引,其值>>

我使用以下代码,但看起来很麻烦。有更简洁的方法吗?

void detect_indices_lesser_greater_equal_to_value(thrust::device_vector<float> S, float value, 
                        int criterion, thrust::device_vector<int>& indices)
{

int N=S.size();

int size=N;


if(criterion==0) // criterion =0 => equal
{
thrust::device_vector<int>::iterator end = thrust::copy_if(thrust::device,thrust::make_counting_iterator(0),
                                                             thrust::make_counting_iterator(N),
                                                             S.begin(),
                                                             indices.begin(), 
                                                             thrust::placeholders::_1 == value);
size = end-indices.begin();
}


if(criterion==1) // criterion =1 => less
{
thrust::device_vector<int>::iterator end = thrust::copy_if(thrust::device,thrust::make_counting_iterator(0),
                                                             thrust::make_counting_iterator(N),
                                                             S.begin(),
                                                             indices.begin(), 
                                                             thrust::placeholders::_1 < value);
size = end-indices.begin();
}

if(criterion==2) // criterion =2 => greater
{
thrust::device_vector<int>::iterator end = thrust::copy_if(thrust::device,thrust::make_counting_iterator(0),
                                                             thrust::make_counting_iterator(N),
                                                             S.begin(),
                                                             indices.begin(), 
                                                             thrust::placeholders::_1 > value);
size = end-indices.begin();
}

indices.resize(size);

}

1 个答案:

答案 0 :(得分:2)

这可以通过两个thrust::partition操作来完成。分区非常简单:将所有导致真实谓词的内容移至输入向量的左侧。其他所有内容都移到右侧。这是一个简单的示例:

$ cat t22.cu
#include <thrust/partition.h>
#include <thrust/copy.h>
#include <thrust/device_vector.h>

typedef float mt;

using namespace thrust::placeholders;
int main(){

  const mt pval = 4;
  mt data[] = {1,3,7,4,5,2,4,3,9};
  const int ds = sizeof(data)/sizeof(data[0]);
  thrust::device_vector<mt> d(data, data+ds);
  auto end1 = thrust::partition(d.begin(), d.end(), _1<pval);
  auto end2 = thrust::partition(end1, d.end(), _1==pval);
  std::cout << "less than pval:" << std::endl;
  thrust::copy(d.begin(), end1, std::ostream_iterator<mt>(std::cout,","));
  std::cout << std::endl << "equal to pval:" << std::endl;
  thrust::copy(end1, end2, std::ostream_iterator<mt>(std::cout,","));
  std::cout << std::endl << "greater than pval:" << std::endl;
  thrust::copy(end2, d.end(), std::ostream_iterator<mt>(std::cout,","));
  std::cout << std::endl;
}
$ nvcc -o t22 t22.cu
$ ./t22
less than pval:
1,3,2,3,
equal to pval:
4,4,
greater than pval:
7,5,9,
$

如果您要求3个结果子向量中的顺序与原始输入顺序相同,则可以使用thrust::stable_partition变体。

(请注意,在您的问题中,您引用的是float个数量,但是您的示例代码使用了<int>迭代器。但是,上述代码可以通过修改typedef来使用)。 / p>