当我在内核包装函数中应用模板技术时遇到问题。
这是我最初的想法:
//----------------------------------------
// cuda_demo.cuh
template<typename T>
void kernel_wrapper(T param);
//----------------------------------------
// cuda_demo.cu
#include <cuda.h>
#include <cuda_runtime.h>
#include "cuda_demo.cuh"
template<typename T>
__global__ void my_kernel(T param) {
// do something
}
template<typename T>
void kernel_wrapper(T param) {
my_kernel<<<1,1>>>(param);
}
//----------------------------------------
// main.cpp
#include "cuda_demo.cuh"
int main() {
int param = 10;
kernel_wrapper(param);
return 0;
}
很快,我发现应该在头文件中实现模板(请参见Why can templates only be implemented in the header file?)。
我从中得到两种解决方案,常见的解决方案是“将模板声明写在头文件中,然后在实现文件(例如.tpp)中实现该类,并在该实现文件的末尾包含此实现文件。标头”。
所以我更改代码:
//----------------------------------------
// cuda_demo.cuh
template<typename T>
void kernel_wrapper(T param);
#include "cuda_demo.cu"
//----------------------------------------
// cuda_demo.cu
#include <cuda.h>
#include <cuda_runtime.h>
template<typename T>
__global__ void my_kernel(T param) {
// do something
}
template<typename T>
void kernel_wrapper(T param) {
my_kernel<<<1,1>>>(param);
}
编译器给我以下错误:
error: expected primary-expression before < token
my_kernel<<<1,1>>>(param);
当我将所有cuda代码放入“ cuda_demo.cuh”时,会发生相同的错误。
然后我尝试了以下第二种解决方案:
//----------------------------------------
// cuda_demo.cuh
template<typename T>
void kernel_wrapper(T param);
//----------------------------------------
// cuda_demo.cu
#include <cuda.h>
#include <cuda_runtime.h>
#include "cuda_demo.cuh"
template<typename T>
__global__ void my_kernel(T param) {
// do something
}
template<typename T>
void kernel_wrapper(T param) {
my_kernel<<<1,1>>>(param);
}
template void kernel_wrapper<int>(int param);
这个很好用!但是在我的项目中,“ T”不是简单的类型,它可能是递归的,例如
Class_1<Class_2<Class_3<...>>>,
这意味着我无法提前弄清楚“ T”的具体类型。
有人知道如何解决吗?
谢谢。
答案 0 :(得分:0)
我发现了问题的实质。
所有cuda代码必须包含在.cu文件中,以便nvcc可以对其进行编译。感谢您的提醒。 @talonmies。
最近,我发现一些开源项目将cuda,C ++代码混合到.h或.cuh文件中,然后包括.cpp文件和.cu文件中的那些头文件。这使我相信cuda代码可以由gcc编译。
但是我终于发现,尽管许多.cpp文件都包含cuda代码,但它们都没有调用.cpp文件中的cuda函数。而且cuda函数调用仅存在于.cu文件中。
他们是如何做到的?答案是条件编译。这样,.cu文件中的cuda代码将由nvcc编译,而.cpp文件中的cuda代码将被gcc忽略。
对于我最初的问题,最有效的解决方案是将模板cuda代码的所有实现都写入头文件中,并仅在.cu文件中调用内核包装器。
我在这个问题上花了很多时间,希望我的经验能对您有所帮助。