nVidia Thrust:device_ptr Const-Correctness

时间:2015-11-24 11:44:50

标签: c++ cuda thrust

在我广泛使用nVidia CUDA的项目中,我有时会将Thrust用于非常非常好的事情。 Reduce 是一种在该库中特别好实现的算法, reduce 的一个用途是通过将每个元素除以所有元素的总和来规范化非负元素的向量元件。

template <typename T>
void normalise(T const* const d_input, const unsigned int size, T* d_output)
{
    const thrust::device_ptr<T> X = thrust::device_pointer_cast(const_cast<T*>(d_input));
    T sum = thrust::reduce(X, X + size);

    thrust::constant_iterator<T> denominator(sum);
    thrust::device_ptr<T> Y = thrust::device_pointer_cast(d_output);
    thrust::transform(X, X + size, denominator, Y, thrust::divides<T>());
}

T通常为floatdouble

一般来说,我不希望在整个代码库中依赖Thrust,因此我尝试确保上述示例之类的函数只接受原始CUDA设备指针。这意味着一旦它们由NVCC编译,我就可以将它们静态地链接到没有NVCC的其他代码中。

然而,这段代码让我很担心。我希望函数是const-correct但我似乎无法找到const版本的thrust::device_pointer_cast(...) - 这样的事情是否存在?在此版本的代码中,我使用了const_cast,因此我在函数签名中使用了const,这让我很难过。

另一方面,将 reduce 的结果复制到主机只是为了将其发送回设备以进行下一步操作感觉很奇怪。有更好的方法吗?

1 个答案:

答案 0 :(得分:5)

如果你想要const-correctness,你需要在任何地方都是const-correct。 input是指向const T的指针,因此应该是X

const thrust::device_ptr<const T> X = thrust::device_pointer_cast(d_input);