我有 cuda 10.1 , cudnn 7.5.0 ,gpu是 nvidia 940mx 。
CHECK_CUDNN_ERROR(cudnnCreateConvolutionDescriptor(&convolutionDescriptor_));
CHECK_CUDNN_ERROR(cudnnSetConvolution2dDescriptor(convolutionDescriptor_,
benchmarkInput.pad_h,
benchmarkInput.pad_w,
benchmarkInput.stride_h,
benchmarkInput.stride_w,
1,
1,
CUDNN_CONVOLUTION,
dataType));
cudnnSetConvolution2dDescriptor 不适用于int8,uint8,int32,int8x4,int8x32,uint8x4数据格式,并抛出 CUDNN_STATUS_BAD_PARAM 。但是使用 float ,半浮点, double 可以正常工作。
我看了看文档,并说到有关填充,跨距,扩张和模式的无效值。但这不是问题的根源。
是不是cudnn 7.5.0不支持int8等格式?
完整代码为here
//W H C N K S(filter_W) R(filter_H) pad_w pad_h stride_w stride_h out_w out_h input_stride_w input_stride_h filter_stride_w filter_stride_h
56 56 256 16 64 1 1 0 0 1 1 56 56 1 1 1 1
56 56 256 32 64 1 1 0 0 1 1 56 56 1 1 1 1
56 56 256 64 64 1 1 0 0 1 1 56 56 1 1 1 1
56 56 256 16 128 1 1 0 0 1 1 56 56 1 1 1 1
56 56 256 32 128 1 1 0 0 1 1 56 56 1 1 1 1
56 56 256 64 128 1 1 0 0 1 1 56 56 1 1 1 1