如何得到推力指数

时间:2012-08-05 05:20:21

标签: cuda thrust

我正在尝试使用推力为每个设备向量提供某些值 这是代码

const uint N = 222222; 
struct assign_functor
{
  template <typename Tuple>
  __device__ 
  void operator()(Tuple t)
  {  
    uint x = threadIdx.x + blockIdx.x * blockDim.x;
    uint y = threadIdx.y + blockIdx.y * blockDim.y;
    uint offset = x + y * blockDim.x * gridDim.x; 

    thrust::get<0>(t) = offset; 
  }
};
int main(int argc, char** argv)
{ 

  thrust::device_vector <float> d_float_vec(N);  

  thrust::for_each(
    thrust::make_zip_iterator( 
      thrust::make_tuple(d_float_vec.begin()) 
    ), 
    thrust::make_zip_iterator( 
      thrust::make_tuple(d_float_vec.end())
    ), 
    assign_functor()
  );

  std::cout<<d_float_vec[10]<<" "<<d_float_vec[N-2] 
}

d_float_vec [N-2]的输出应为22​​2220;但结果是1036.我的代码有什么问题?

我知道我可以使用thrust :: sequence为向量提供序列值。我只是想知道如何获得推力foreach函数的真实指数。谢谢!

1 个答案:

答案 0 :(得分:2)

正如评论中所指出的那样,你的方法永远不可能奏效,因为你已经假设了很多关于thrust::for_each内部工作方式的事情,这可能不正确,包括:

  • 您隐含地假设for_each使用单个线程来处理每个输入元素。几乎可以肯定不是这样;在操作过程中,推力将更有可能每个线程处理多个元素。
  • 您还假设执行按顺序执行,以便第N个线程处理第N个数组元素。情况可能并非如此,并且执行可能以无法知道的顺序发生 a priori
  • 您假设for_each处理单个内​​核laumch中的整个输入数据集

应将Thrust算法视为黑框,其内部操作未定义,并且不需要了解它们来实现用户定义的仿函数。在您的示例中,如果在仿函数中需要顺序索引,则传递计数迭代器。重写你的例子的一种方法是:

#include "thrust/device_vector.h"
#include "thrust/for_each.h"
#include "thrust/tuple.h"
#include "thrust/iterator/counting_iterator.h"

typedef unsigned int uint;
const uint N = 222222; 
struct assign_functor
{
  template <typename Tuple>
  __device__ 
  void operator()(Tuple t)
  {  
    thrust::get<1>(t) = (float)thrust::get<0>(t);
  }
};

int main(int argc, char** argv)
{ 
  thrust::device_vector <float> d_float_vec(N);  
  thrust::counting_iterator<uint> first(0);
  thrust::counting_iterator<uint> last = first + N;

  thrust::for_each(
    thrust::make_zip_iterator( 
      thrust::make_tuple(first, d_float_vec.begin()) 
    ), 
    thrust::make_zip_iterator( 
      thrust::make_tuple(last, d_float_vec.end())
    ), 
    assign_functor()
  );

  std::cout<<d_float_vec[10]<<" "<<d_float_vec[N-2]<<std::endl; 
}

这里计数迭代器与数据数组一起在元组中传递,允许函数访问与它正在处理的数据数组条目相对应的顺序索引。