在CUDA内核中使用cublasStbsv

时间:2016-06-09 12:29:12

标签: cuda cublas

我试图在我的CUDA内核中使用cublasStbsv函数求解方程式。内核代码如下:

__global__ void invokeDeviceCublasSgemm(cublasStatus_t *returnValue,
                                        int n,
                                        const float *d_alpha,
                                        const float *d_A,
                                        const float *d_B,
                                        const float *d_beta,
                                        float *d_C)
{
    cublasHandle_t cnpHandle;
    cublasStatus_t status = cublasCreate(&cnpHandle);

    if (status != CUBLAS_STATUS_SUCCESS)
    {
        *returnValue = status;
        return;
    }

//    /* Perform operation using cublas */
//    status =
//        cublasSgemm(cnpHandle,
//                    CUBLAS_OP_N, CUBLAS_OP_N,
//                    n, n, n,
//                    d_alpha,
//                    d_A, n,
//                    d_B, n,
//                    d_beta,
//                    d_C, n);

    float d_AA[5*5];
    float d_BB[5];
//    float d_X[5];

    for(int i=0;i<5;i++)
    {
        for(int j=0;j<5;j++)
        {
            if(i==j)
            {
                d_AA[i*5+j] = i;
            }else
            {
                d_AA[i*5+j] = 0;
            }

        }
        d_BB[i] = i*i;
    }

    status = cublasStbsv(cnpHandle,
                         CUBLAS_FILL_MODE_UPPER,
                         CUBLAS_OP_N,
                         CUBLAS_DIAG_NON_UNIT,
                         n,n,
                         d_AA,
                         5,
                         d_BB,
                         1);

    for(int i=0;i<5;i++)
    {
           printf("B i %d %f \n",i,d_BB[i]);
    }

    cublasDestroy(cnpHandle);

    *returnValue = status;
}

我不明白为什么会收到以下错误:

  

开始/ home / xavier / Bureau / Developpement / Cuda / build-Cuda_CUBLAS-Qt_5_6_0_gcc_64-Release / Cuda_CUBLAS ......

     

simpleDevLibCUBLAS测试运行...
  GPU设备0:&#34; GeForce GTX 750 Ti&#34;具有计算能力5.0

     

将测试主机和设备API    **进入SBSV参数号7时有非法值
  B i 0 0.000000
  B i 1 1.000000
  B i 2 4.000000
  B i 3 9.000000
  B i 4 16.000000
  !!!! CUBLAS Device API调用失败,代码为7

我不明白我应该使用哪个函数来解决线性方程 - cublasStpsvcublasStrs。有人能帮助我吗?

1 个答案:

答案 0 :(得分:1)

tbsp()期望一个三角形带状矩阵; tpsv()期望以打包格式存储三角矩阵;并且trsv()需要一个密集的矩阵,只使用上/下部分。

根据您的代码,我认为您需要trsv()