使用cudnnSetConvolution2dDescriptor时,CUDNN_STATUS_BAD_PARAM错误的原因是什么?

时间:2019-04-20 10:30:51

标签: conv-neural-network convolution cudnn

我有 cuda 10.1 cudnn 7.5.0 ,gpu是 nvidia 940mx

    CHECK_CUDNN_ERROR(cudnnCreateConvolutionDescriptor(&convolutionDescriptor_));
    CHECK_CUDNN_ERROR(cudnnSetConvolution2dDescriptor(convolutionDescriptor_,
                                                      benchmarkInput.pad_h,
                                                      benchmarkInput.pad_w,
                                                      benchmarkInput.stride_h,
                                                      benchmarkInput.stride_w,
                                                      1,
                                                      1,
                                                      CUDNN_CONVOLUTION,
                                                      dataType));

cudnnSetConvolution2dDescriptor 不适用于int8,uint8,int32,int8x4,int8x32,uint8x4数据格式,并抛出 CUDNN_STATUS_BAD_PARAM 。但是使用 float 半浮点 double 可以正常工作。

我看了看文档,并说到有关填充,跨距,扩张和模式的无效值。但这不是问题的根源。

是不是cudnn 7.5.0不支持int8等格式?

完整代码为here

卷积的输入值:

//W H   C   N   K   S(filter_W) R(filter_H) pad_w   pad_h   stride_w    stride_h    out_w   out_h   input_stride_w  input_stride_h  filter_stride_w filter_stride_h
56  56  256 16  64  1   1   0   0   1   1   56  56  1   1   1   1
56  56  256 32  64  1   1   0   0   1   1   56  56  1   1   1   1
56  56  256 64  64  1   1   0   0   1   1   56  56  1   1   1   1
56  56  256 16  128 1   1   0   0   1   1   56  56  1   1   1   1
56  56  256 32  128 1   1   0   0   1   1   56  56  1   1   1   1
56  56  256 64  128 1   1   0   0   1   1   56  56  1   1   1   1

0 个答案:

没有答案