使用Thrust

时间:2016-11-08 18:09:59

标签: c++11 cuda functor thrust

我尝试使用Thrust创建设备仿函数,它将对设备数据结构的引用存储为其状态。然后,仿函数将传递给thrust::transform()和朋友。问题是我在仿函数的return语句中从设备代码调用主机函数时遇到错误:

// Compile with:
// nvcc --std=c++11 device_functor.cu -o device_functor

#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <thrust/complex.h>

#include <iostream>
#include <iomanip>

struct my_functor {
    my_functor(thrust::device_vector<unsigned char> &octets_) :
        octets(octets_) {};

    __device__
    unsigned char operator()(int idx) const {
        return octets[idx];
    }
private:
    thrust::device_vector<unsigned char> &octets;
};


int main() {
    thrust::device_vector<unsigned char> d_octets (4);

    my_functor foo(d_octets);

    d_octets[0] = 0x00;
    d_octets[1] = 0x01;
    d_octets[2] = 0x02;
    d_octets[3] = 0x03;

    std::cout << "0x" << std::hex << std::setfill('0') << std::setw(2) << static_cast <int> (foo(2)) << std::endl;

    return 0;
}

这样做的最终目标之一是在转换中以各种方式访问​​octets中的位,例如抓取第五组三位,第十组四位等等。&#39一旦我可以让仿函数工作,这一切都很容易。

1 个答案:

答案 0 :(得分:1)

重新编写仿函数以解决@talonmies的评论似乎可以解决问题:

struct my_functor {
    my_functor(thrust::device_vector<unsigned char> &octets) :
        octet_ptr(thrust::raw_pointer_cast(&octets[0])) {};

    __device__
    unsigned char operator()(int idx) const {
        return *(octet_ptr + idx);
    }

private:
    unsigned char *octet_ptr;
};