C ++ struct functor作为函数模板参数

时间:2015-12-01 20:58:09

标签: c++ templates struct cuda

我很困惑,我知道CUDA和其他库允许使用模板结构作为仿函数。因此,我为神经网络课程设计了一些:

struct sigmoid
{
     sigmoid()=default;                                                           
     __device__ float operator()(const float x) const                                                                                     
    {                                                                                                                                     
         float exp_val = __expf(-x);                                                                                                       
         float denom = __fadd_rz(1.f,exp_val);                                                                                             
         return __fdividef(1.f,denom);                                                
    }                                                                     
};       

当我将它用于CUDA内核时,它的用法有点简单:

activate<sigmoid><<num_blocks_x,block_threads_x>>>(sigmoid(),output_ptr);

有关:

template <typename F>                                                                                                                     
__global__ void activate(F const& func, float * input)                                                                                    
{                                                                                                                                         
   int x = blockIdx.x * blockDim.x + threadIdx.x;                                                                                        
   input[x]  = func(input[x]);                                                                                                           
} 

但是我想将函数模板包装在中,调用CUDA内核,然后将其转发给它:

template <class A>                                                                                                             
thrust::host_vector<float> propagate (                                                                                                
                                       A func,                                                                 
                                       thrust::device_vector<float> & input                                                          
                                     ) const; 

我已将其实现为一个单独的标头,该标头包含在声明该类的标头的末尾。

class ann
{
...
};
#include ann_imp.hpp

和imp标题:

template <class A> inline                                                                                                                   
__host__ thrust::host_vector<float> ann::propagate (                                                                                        
                                                       A func,                                                                            
                                                       thrust::device_vector<float> & input                                               
                                                    ) const                                                                                 
{
     activate<func><<<num_blocks_x,block_threads_x>>>(func(),output_ptr);
}  

然而,当我调用实际的propagate方法时,我遇到了麻烦:

net.propagate<sigmoid>( sigmoid(), in_vec1 );

产地:

error: function "sigmoid::operator()" cannot be called with the given argument list
            object type is: sigmoid

当我不使用operator()但仅使用typename:

xor_net.propagate<sigmoid>( sigmoid, in_vec1 );

我明白了:

error: type name is not allowed

使用实际对象会产生相同的错误:

sigmoid func;
xor_net.propagate<sigmoid>( func, in_vec1 );

我尝试过使用参数A const& func等等,但无济于事。

如何传递struct functor,然后将其转发到CUDA内核?

修改 如果没有包装器,只需要调用激活函数:

activate<sigmoid><<<num_blocks_x,block_threads_x>>>(sigmoid(),output_ptr); 

1 个答案:

答案 0 :(得分:1)

你有:

 __device__ float operator()(const float x) const ...

该函数需要类型为float的参数。您是从ann::propagate将其称为:

activate<func><<<num_blocks_x,block_threads_x>>>(func(),output_ptr);
                                                 ^^^^^^

我认为该行必须是:

activate<A><<<num_blocks_x,block_threads_x>>>(func,output_ptr);
       ^^^^                                   ^^^^^     
       Fix the type                           Use the object.