我使用CUDA和THRUST来执行配对设置操作。但是,我想保留重复项。例如:
int keys[6] = {1, 1, 1, 3, 4, 5, 5};
int vals[6] = {1, 2, 3, 4, 5, 6, 7};
int comp[2] = {1, 5};
thrust::set_intersection_by_key(keys, keys + 6, comp, comp + 2, vals, rk, rv);
期望的结果
rk[1, 1, 1, 5, 5]
rv[1, 2, 3, 6, 7]
实际结果
rk[1, 5]
rv[5, 7]
我想要所有 vals , comp 中包含相应的键。
有没有办法用推力实现这个目标,还是我必须编写自己的内核或推力函数?
我正在使用此功能:set_intersection_by_key。
答案 0 :(得分:3)
概括是如果一个元素在[keys_first1,keys_last1]中出现m次,在[keys_first2,keys_last2)中出现n次(其中m可能为零),那么它在键输出中出现min(m,n)次范围
由于comp
仅包含每个密钥一次,n=1
因此min(m,1) = 1
。
为了获得 comp 中包含相应键的所有 vals “,您可以使用{的方法{3}}
类似地,示例代码执行以下步骤:
获取d_comp
的最大元素。这假定d_comp
已经排序。
创建大小为d_map
的向量largest_element+1
。将1
复制到d_comp
中d_map
条目的所有位置。
将d_vals
中1
条目d_map
的所有条目复制到d_result
。
#include <thrust/device_vector.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/iterator/permutation_iterator.h>
#include <thrust/functional.h>
#include <thrust/copy.h>
#include <thrust/scatter.h>
#include <iostream>
#define PRINTER(name) print(#name, (name))
void print(const char* name, const thrust::device_vector<int>& v)
{
std::cout << name << ":\t";
thrust::copy(v.begin(), v.end(), std::ostream_iterator<int>(std::cout, "\t"));
std::cout << std::endl;
}
int main()
{
int keys[] = {1, 1, 1, 3, 4, 5, 5};
int vals[] = {1, 2, 3, 4, 5, 6, 7};
int comp[] = {1, 5};
const int size_data = sizeof(keys)/sizeof(keys[0]);
const int size_comp = sizeof(comp)/sizeof(comp[0]);
// copy data to GPU
thrust::device_vector<int> d_keys (keys, keys+size_data);
thrust::device_vector<int> d_vals (vals, vals+size_data);
thrust::device_vector<int> d_comp (comp, comp+size_comp);
PRINTER(d_keys);
PRINTER(d_vals);
PRINTER(d_comp);
int largest_element = d_comp.back();
thrust::device_vector<int> d_map(largest_element+1);
thrust::constant_iterator<int> one(1);
thrust::scatter(one, one+size_comp, d_comp.begin(), d_map.begin());
PRINTER(d_map);
thrust::device_vector<int> d_result(size_data);
using namespace thrust::placeholders;
int final_size = thrust::copy_if(d_vals.begin(),
d_vals.end(),
thrust::make_permutation_iterator(d_map.begin(), d_keys.begin()),
d_result.begin(),
_1
) - d_result.begin();
d_result.resize(final_size);
PRINTER(d_result);
return 0;
}
<强>输出强>:
d_keys: 1 1 1 3 4 5 5
d_vals: 1 2 3 4 5 6 7
d_comp: 1 5
d_map: 0 1 0 0 0 1
d_result: 1 2 3 6 7