通过构造函数将device_vector存储在函子中吗?

时间:2018-07-11 18:43:59

标签: cuda functor thrust

我正在尝试将thrust::device_vector存储在函子中。简单的解释如下:

struct StructOperator : public thrust::unary_function<float, int>  {
  int num_;
  thrust::device_vector<int> v_test;

  explicit StructOperator(thrust::device_vector<int> const& input_v) :
    v_test(input_v), num_(input_v.size()) {};

  __host__ __device__
   float operator()(int index) {
      // magic happens
   }
};

无法编译-nvcc一直说不允许从__host__调用__host__ __device__。我见过this问题-这是实现这一目标的唯一方法吗?

1 个答案:

答案 0 :(得分:4)

现在将__device__装饰器放在functor运算符上时,现在您只能在该运算符主体中执行的操作仅限于与CUDA设备代码兼容的事物。

thrust::device_vector是一个类定义,旨在促进推力的表达/计算模型(大致类似于STL容器/算法)。因此,它在其中包含主机和设备代码。 thrust::device_vector中的主机代码未经过修饰,无法在设备上使用,并且普通主机代码不可在CUDA设备代码中使用。

thrust::device_vector既不设计也不打算直接在设备代码中使用。不能按照您的建议使用它。与推测的结果相反,它并非设计为可在CUDA设备代码中使用的std::vector的类似物。它被设计为std::vector的类似物,可用于推力算法(根据设计,可从主机代码调用/使用)。这就是为什么您在编译时收到消息的原因,并且没有简单的方法(*)来解决。

大概thrust::device_vector的主要目的是充当容器来保存设备上可用/可访问的数据。 CUDA设备代码中已支持的POD类型数据中最直接的等效项是数组或数据指针。

因此,我认为以“是”回答您的问题是合理的-这是实现此目标的唯一方法。

  • 我正在尝试各种类似的方法,例如传递推力指针而不是裸指针。
  • (*)我忽略了这样的想法,例如编写自己的容器类以允许在设备上使用,或进行大量修改以使自身以某种方式允许这种行为。

这是一个完全有效的示例,围绕您所显示的内容:

$ cat t1385.cu
#include <iostream>
#include <thrust/device_vector.h>
#include <thrust/transform.h>
#include <thrust/copy.h>


struct StructOperator : public thrust::unary_function<float, int>  {
  int num_;
  int *v_test;

  explicit StructOperator(int *input_v, int input_v_size) :
    v_test(input_v), num_(input_v_size) {};

  __host__ __device__
   float operator()(int index) {
      if (index < num_)  return v_test[index] + 0.5;
      return 0.0f;
   }
};

const int ds = 3;
int main(){

  thrust::device_vector<int> d(ds);
  thrust::sequence(d.begin(), d.end());
  thrust::device_vector<float> r(ds);
  thrust::transform(d.begin(), d.end(), r.begin(), StructOperator(thrust::raw_pointer_cast(d.data()), d.size()));
  thrust::copy(r.begin(), r.end(), std::ostream_iterator<float>(std::cout, ","));
  std::cout << std::endl;
}
$ nvcc t1385.cu -o t1385
$ ./t1385
0.5,1.5,2.5,
$