CUDA device array implementation using templates

Asked: 2018-07-19 06:46:00

Tags: c++ templates cuda

I am trying to implement a fixed-size version of Thrust's device vector. I wrote an initial version, but I am running into a strange template error.

Here is the code:

#include <iostream>
#include <array>

enum class memcpy_t {
    host_to_host,
    host_to_device,
    device_to_host,
    device_to_device
};

template <typename T, std::size_t N>
struct cuda_allocator {
    using pointer = T*;

    static void allocate(T *dev_mem) {
        cudaMalloc(&dev_mem, N * sizeof(T)); 
    }

    static void deallocate(T *dev_mem) {
        cudaFree(dev_mem); 
    }

    template <memcpy_t ct>
    static void copy (T *dst, T *src) {
        switch(ct) {
        case memcpy_t::host_to_host:
            cudaMemcpy(dst, src, N * sizeof(T), cudaMemcpyHostToHost);
            break;
        case memcpy_t::host_to_device:
            cudaMemcpy(dst, src, N * sizeof(T), cudaMemcpyHostToDevice);
            break;
        case memcpy_t::device_to_host:
            cudaMemcpy(dst, src, N * sizeof(T), cudaMemcpyDeviceToHost);
            break;
        case memcpy_t::device_to_device:
            cudaMemcpy(dst, src, N * sizeof(T), cudaMemcpyDeviceToDevice);
            break;
        default:
            break;
        }
    }
};

template <typename T, std::size_t N>
struct gpu_array {
    using allocator = cuda_allocator<T, N>;
    using pointer = typename allocator::pointer;
    using value_type = T;
    using iterator = T*;
    using const_iterator = T const*;

    gpu_array() {
        allocator::allocate(data);
    }

    gpu_array(std::array<T, N> host_arr) {
        allocator::allocate(data);
        allocator::copy<memcpy_t::host_to_device>(data, host_arr.begin());
    }

    gpu_array& operator=(gpu_array const& o) {
        allocator::allocate(data);
        allocator::copy<memcpy_t::device_to_device>(data, o.begin());
    }

    operator std::array<T, N>() {
        std::array<T, N> res;
        allocator::copy<memcpy_t::device_to_host>(res.begin(), data);
        return res;
    }

    ~gpu_array() {
        allocator::deallocate(data);
    }

    __device__ iterator begin() { return data; }
    __device__ iterator end() { return data + N; }
    __device__ const_iterator begin() const { return data; }
    __device__ const_iterator end() const { return data + N; }

private:
    T* data;
};

template <typename T, std::size_t N>
__global__ void add_kernel(gpu_array<T,N> &r,
                           gpu_array<T,N> const&a1,
                           gpu_array<T,N> const&a2) {
    int i = blockIdx.x*blockDim.x + threadIdx.x;
    r.begin()[i] = a1.begin()[i] + a2.begin()[i];
}

template <typename T, std::size_t N>
gpu_array<T, N> operator+(gpu_array<T,N> const&a1,
                          gpu_array<T,N> const&a2)
{
    gpu_array<T, N> res;
    add_kernel<<<(N+255)/256, 256>>>(res, a1, a2);
    return res;
}

const int N = 1<<20;

int main() {
    std::array<float, N> x,y;

    for (int i = 0; i < N; i++) {
        x[i] = 1.0f;
        y[i] = 2.0f;
    } 

    gpu_array<float, N> dx{x};
    gpu_array<float, N> dy{y};

    std::array<float, N> res = dx + dy;

    for(const auto& elem : res) {
        std::cout << elem << ", ";
    }
}

There are probably many other mistakes, but I am stuck on one strange error. nvcc gives me the following:

error: no match for 'operator<' (operand types are '<unresolved overloaded function type>' and 'memcpy_t')
allocator::copy<memcpy_t::host_to_device>(data, host_arr.begin());

Is it, for some reason, treating my enum class template argument as an operand of operator<? By the way, this is compiled with the options -arch=sm_70 -std=c++14. I don't understand how C++ and CUDA interact here, so I can't figure out the problem.

1 Answer:

Answer (score: 2):

It took a bit of head scratching, but the underlying problem here is that the syntax is defective according to the C++ standard. As far as I can tell, the error is emitted by the host compiler, and it is perfectly correct to do so. See here for all the gory details.
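To illustrate the rule outside of CUDA: because allocator names a type that depends on the enclosing template's parameters, its member copy is a dependent name, and the compiler will not assume the < after it starts a template argument list unless you say so with the template keyword. Here is a minimal, CUDA-free sketch of the same situation (the names alloc and wrapper are made up for illustration and are not part of the original answer):

// Minimal sketch of the dependent-name parsing rule.
#include <cstddef>

template <typename T, std::size_t N>
struct alloc {
    template <int Kind>
    static void copy(T *dst, T const *src) { /* no-op, for illustration only */ }
};

template <typename T, std::size_t N>
struct wrapper {
    // alloc<T, N> depends on wrapper's template parameters,
    // so its members are dependent names inside wrapper.
    using allocator = alloc<T, N>;

    void demo(T *dst, T const *src) {
        // allocator::copy<0>(dst, src);        // parsed as (allocator::copy < 0) > ... -> error
        allocator::template copy<0>(dst, src);  // the template keyword disambiguates
    }
};

int main() {
    float a[4] = {}, b[4] = {};
    wrapper<float, 4>{}.demo(a, b);
}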

The code that uses your copy specialization should therefore look like this:

gpu_array(std::array<T, N> host_arr) {
    allocator::allocate(data);
    allocator::template copy<memcpy_t::host_to_device>(data, host_arr.begin());
}

gpu_array& operator=(gpu_array const& o) {
    allocator::allocate(data);
    allocator::template copy<memcpy_t::device_to_device>(data, o.begin());
}

operator std::array<T, N>() {
    std::array<T, N> res;
    allocator::template copy<memcpy_t::device_to_host>(res.begin(), data);
    return res;
}

It is possibly the strangest syntax ever, but it is what is required to make the compiler treat the < as introducing a template argument list rather than as an operator. Fix this everywhere it occurs in your code and that particular compiler error should disappear.
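As a side note that is not part of the original answer: you could also sidestep the disambiguator entirely by passing the copy kind as an ordinary function argument rather than a template argument. Since cudaMemcpy takes the kind as a run-time argument anyway, nothing is lost by making the choice at run time. A rough sketch of such a drop-in replacement for copy inside the questioner's cuda_allocator might look like this:

// Hypothetical variant of cuda_allocator<T, N>::copy taking the kind as a
// normal parameter, so no template keyword is needed at the call site.
static void copy(T *dst, T const *src, memcpy_t ct) {
    cudaMemcpyKind kind = cudaMemcpyDefault;
    switch (ct) {
    case memcpy_t::host_to_host:     kind = cudaMemcpyHostToHost;     break;
    case memcpy_t::host_to_device:   kind = cudaMemcpyHostToDevice;   break;
    case memcpy_t::device_to_host:   kind = cudaMemcpyDeviceToHost;   break;
    case memcpy_t::device_to_device: kind = cudaMemcpyDeviceToDevice; break;
    }
    cudaMemcpy(dst, src, N * sizeof(T), kind);
}

// The call site in gpu_array then reads simply:
//     allocator::copy(data, host_arr.data(), memcpy_t::host_to_device);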