如何在CUDA上部分排序数组?

时间:2013-05-15 02:12:31

标签: sorting cuda thrust

问题

如果我有两个数组:

   const int N = 1000000;
   float A[N];
   myStruct *B[N];

A中的数字可以是正数或负数(例如A[N]={3,2,-1,0,5,-2}),如何使数组A 部分排序(所有正值首先,不需要排序,然后是负值)< / strong>(例如A[N]={3,2,5,0,-1,-2}A[N]={5,2,3,0,-2,-1})在GPU上?阵列B应根据A改变(A是键,B是值)。

由于A,B的规模可能非常大,我认为排序算法应该在GPU上实现(特别是在CUDA上,因为我使用这个平台)。当然我知道thrust::sort_by_key可以完成这项工作,但它确实可以解决额外的工作,因为我不需要完全排序数组A&amp; B。

有没有人遇到过这种问题?

推力示例

thrust::sort_by_key(thrust::device_ptr<float> (A), 
            thrust::device_ptr<float> ( A + N ),  
            thrust::device_ptr<myStruct> ( B ),  
            thrust::greater<float>() );

3 个答案:

答案 0 :(得分:1)

Thrust在Github上的文档并不是最新的。正如@JaredHoberock所说,thrust::partition是现在的发展方式supports stencils。您可能需要从Github repository获取副本:

  

git clone git://github.com/thrust/thrust.git

然后在Thrust文件夹中运行scons doc以获取更新的文档,并在编译代码时使用这些更新的Thrust源(nvcc -I/path/to/thrust ...)。使用新的模板分区,您可以执行以下操作:

#include <thrust/partition.h>
#include <thrust/execution_policy.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/tuple.h>

struct is_positive
{
__host__ __device__
bool operator()(const int &x)
{
  return x >= 0;
}
};


thrust::partition(thrust::host, // if you want to test on the host
                  thrust::make_zip_iterator(thrust::make_tuple(keyVec.begin(), valVec.begin())),
                  thrust::make_zip_iterator(thrust::make_tuple(keyVec.end(), valVec.end())),
                  keyVec.begin(),
                  is_positive());

返回:

Before:
  keyVec =   0  -1   2  -3   4  -5   6  -7   8  -9
  valVec =   0   1   2   3   4   5   6   7   8   9
After:
  keyVec =   0   2   4   6   8  -5  -3  -7  -1  -9
  valVec =   0   2   4   6   8   5   3   7   1   9

请注意,2个分区不一定要排序。此外,原始矢量和分区之间的顺序可能不同。如果这对您很重要,您可以使用thrust::stable_partition

  

stable_partition与partition的区别在于stable_partition是   保证保持相对顺序。也就是说,如果x和y是   [first,last]中的元素,如pred(x)== pred(y),如果是x   在y之前,然后在stable_partition之后它仍然是真的x   在y之前。

如果你想要一个完整的例子,那就是:

#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <thrust/partition.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/tuple.h>

struct is_positive
{
__host__ __device__
bool operator()(const int &x)
{
  return x >= 0;
}
};

void print_vec(const thrust::host_vector<int>& v)
{
  for(size_t i = 0; i < v.size(); i++)
    std::cout << "  " << v[i];
  std::cout << "\n";
}

int main ()
{
  const int N = 10;

  thrust::host_vector<int> keyVec(N);
  thrust::host_vector<int> valVec(N);

  int sign = 1;
  for(int i = 0; i < N; ++i)
  {
    keyVec[i] = sign * i;
    valVec[i] = i;
    sign *= -1;
  }

  // Copy host to device
  thrust::device_vector<int> d_keyVec = keyVec;
  thrust::device_vector<int> d_valVec = valVec;

  std::cout << "Before:\n  keyVec = ";
  print_vec(keyVec);
  std::cout << "  valVec = ";
  print_vec(valVec);

  // Partition key-val on device
  thrust::partition(thrust::make_zip_iterator(thrust::make_tuple(d_keyVec.begin(), d_valVec.begin())),
                    thrust::make_zip_iterator(thrust::make_tuple(d_keyVec.end(), d_valVec.end())),
                    d_keyVec.begin(),
                    is_positive());

  // Copy result back to host
  keyVec = d_keyVec;
  valVec = d_valVec;

  std::cout << "After:\n  keyVec = ";
  print_vec(keyVec);
  std::cout << "  valVec = ";
  print_vec(valVec);
}

更新

我与thrust::sort_by_key版本进行了快速比较,thrust::partition实现似乎更快(这是我们自然可以预期的)。以下是我在NVIDIA Visual Profiler上使用N = 1024 * 1024获得的内容,左侧是排序版,右侧是分区版。您可能希望自己进行相同类型的测试。

Sort vs Partition

答案 1 :(得分:0)

这个怎么样?:

  1. 计算确定拐点的正数是多少
  2. 将拐点的每一侧均匀分成组(负组的长度相同但长度与正组不同。这些组是结果的记忆块)
  3. 每个块对使用一个内核调用(一个线程)
  4. 每个内核将输入组中的任何不合适的元素交换为所需的输出组。您需要标记任何具有比最大值更多的交换的块,以便您可以在后续迭代期间修复它们。
  5. 重复完成
  6. 内存流量仅为交换(从原始元素位置到排序位置)。我不知道这个算法是否听起来像任何已定义的......

答案 2 :(得分:0)

只需修改比较运算符,您就可以实现这一目标:

struct my_compare
{
  __device__ __host__ bool operator()(const float x, const float y) const
  {
    return !((x<0.0f) && (y>0.0f));
  }
};


thrust::sort_by_key(thrust::device_ptr<float> (A), 
            thrust::device_ptr<float> ( A + N ),  
            thrust::device_ptr<myStruct> ( B ),  
            my_compare() );