将内部函数作为模板参数传递

时间:2018-02-19 19:24:13

标签: c++ templates cuda

我试图将atomicAdd函数作为模板参数传递给另一个函数。

这是我的Kernel1:

template<typename T, typename TAtomic>
__global__ void myfunc1(T *address, TAtomic atomicFunc) {
    atomicFunc(address, 1);
}

尝试1:

myfunc1<<<1,1>>>(val.dev_ptr, atomicAdd);

由于编译器无法匹配预期的函数签名,因此无效。

尝试2: 首先,我将atomicAdd包装到一个名为MyAtomicAdd的自定义函数中。

template<typename T>
__device__ void MyAtomicAdd(T *address, T val) {
    atomicAdd(address, val);
}

然后,我定义了一个名为&#34; TAtomic&#34;的函数指针。并将TAtomic声明为模板参数。

typedef void (*TAtomic)(float *,float);

template<typename T, TAtomic atomicFunc>
__global__ void myfunc2(T *address) {
    atomicFunc(address, 1);
}

myfunc2<float, MyAtomicAdd><<<1,1>>>(dev_ptr);
CUDA_CHECK(cudaDeviceSynchronize());

实际上,Try 2有效。但是,我不想使用typedef。我需要更通用的东西。

尝试3: 只需将MyAtomicAdd传递给myfunc1。

myfunc1<<<1,1>>>(dev_ptr, MyAtomicAdd<float>);
CUDA_CHECK(cudaDeviceSynchronize());

编译器可以编译代码。但是当我运行该程序时,报告错误:

"ERROR in /home/liang/groute-dev/samples/framework/pagerank.cu:70: invalid program counter (76)"

我只是想知道,为什么尝试3不起作用?任何简单或温和的方式都可以实现这一要求吗?谢谢。

1 个答案:

答案 0 :(得分:1)

尝试3不起作用,因为您试图在主机代码中获取__device__函数的地址,这在CUDA中是非法的:

myfunc1<<<1,1>>>(dev_ptr, MyAtomicAdd<float>);
                          ^
                          effectively a function pointer - address of a __device__ function

CUDA中的这种使用尝试将解析为某种“地址” - 但它是垃圾,所以当您尝试将其用作设备代码中的实际函数入口点时,您会收到遇到的错误:{{ 1}}(或者在某些情况下,只是invalid program counter)。

您可以通过将内在函数包装在仿函数而不是裸illegal address函数中来使您的Try 3方法有效(没有typedef):

__device__

我们也可以实现一个稍微简单的版本,让仿函数模板类型从内核模板类型中推断出来:

$ cat t48.cu
#include <stdio.h>

template<typename T>
__device__ void MyAtomicAdd(T *address, T val) {
    atomicAdd(address, val);
}


template <typename T>
struct myatomicadd
{
  __device__ T operator()(T *addr, T val){
    return atomicAdd(addr, val);
  }
};

template<typename T, typename TAtomic>
__global__ void myfunc1(T *address, TAtomic atomicFunc) {
    atomicFunc(address, (T)1);
}


int main(){

  int *dev_ptr;
  cudaMalloc(&dev_ptr, sizeof(int));
  cudaMemset(dev_ptr, 0, sizeof(int));
//  myfunc1<<<1,1>>>(dev_ptr, MyAtomicAdd<int>);
  myfunc1<<<1,1>>>(dev_ptr, myatomicadd<int>());
  int h = 0;
  cudaMemcpy(&h, dev_ptr, sizeof(int), cudaMemcpyDeviceToHost);
  printf("h = %d\n", h);
  return 0;
}
$ nvcc -arch=sm_35 -o t48 t48.cu
$ cuda-memcheck ./t48
========= CUDA-MEMCHECK
h = 1
========= ERROR SUMMARY: 0 errors
$

在CUDA中使用设备函数指针的更多处理方法与this answer相关联。