cublasDGemm奇怪的结果

时间:2017-08-16 07:41:02

标签: c cuda cublas

我正在尝试使用cublasDgemm()通过其转置来计算矩阵的乘积。我希望从我的代码中得到的输入矩阵和输出如下(分别为A和C):

    | 1 4 7 |        | 66 78 |
A = | 2 5 8 |    C = | 78 93 |

然而,我得到了奇怪的结果,我很难理解cublas / cuda使用的维度(列专业)。任何提示将不胜感激!

#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <cuda_runtime.h>
#include "cublas_v2.h"
#define M 3
#define N 2
#define IDX2C(i,j,ld) (((j)*(ld))+(i))

int main (void){
    cudaError_t cudaStat;    
    cublasStatus_t stat;
    cublasHandle_t handle;
    int i, j;
    double *devPtrA, *devPtrC;
    double *a = 0, *c = 0;

    const double alpha = 1;
    const double beta = 0;

    // initialize host arrays
    a = (double *)malloc (M * N * sizeof (*a));
    c = (double *)malloc (N * N * sizeof (*c));
    if (!a || !c) {
        printf ("host memory allocation failed");
        return EXIT_FAILURE;
    }

    // fill input array
    for (j = 0; j < N; j++) {
        for (i = 0; i < M; i++) {
            a[IDX2C(i,j,M)] = (double)(i * M + j + 1);
            printf ("%7.0f", a[IDX2C(i,j,M)]);
        }
        printf ("\n");
    }

    // set device to 0 (for double processing)
    cudaStat = cudaSetDevice(0);
    if (cudaStat != cudaSuccess) {
        printf("could not set device 0");
        return EXIT_FAILURE;
    }

    // allocate device arrays
    cudaStat = cudaMalloc ((void**)&devPtrA, M*N*sizeof(*a));
    if (cudaStat != cudaSuccess) {
        printf ("device memory allocation of A failed");
        return EXIT_FAILURE;
    }
    cudaStat = cudaMalloc ((void**)&devPtrC, N*N*sizeof(*c));
    if (cudaStat != cudaSuccess) {
        printf ("device memory allocation of C failed");
        return EXIT_FAILURE;
    }

    // create the cublas handle
    stat = cublasCreate(&handle);
    if (stat != CUBLAS_STATUS_SUCCESS) {
        printf ("CUBLAS initialization failed\n");
        return EXIT_FAILURE;
    }

    // set the matrix a
    stat = cublasSetMatrix (M, N, sizeof(*a), a, M, devPtrA, M);
    if (stat != CUBLAS_STATUS_SUCCESS) {
        printf ("data download failed");
        cudaFree (devPtrA);
        cudaFree (devPtrC);
        cublasDestroy(handle);
        return EXIT_FAILURE;
    }

    stat = cublasDgemm(handle, CUBLAS_OP_N, CUBLAS_OP_T, M, M, N, &alpha, devPtrA, M, devPtrA, M, &beta, devPtrC, M);
    if (stat!= CUBLAS_STATUS_SUCCESS) {
        switch (stat) {
            case CUBLAS_STATUS_NOT_INITIALIZED:
                printf("CUBLAS_STATUS_NOT_INITIALIZED\n");
                break;
            case CUBLAS_STATUS_INVALID_VALUE:
                printf("CUBLAS_STATUS_INVALID_VALUE\n");
                break;
            case CUBLAS_STATUS_ARCH_MISMATCH:
                printf("CUBLAS_STATUS_ARCH_MISMATCH\n");
                break;
            case CUBLAS_STATUS_EXECUTION_FAILED:
                printf("CUBLAS_STATUS_EXECUTION_FAILED\n");
                break;
            default:
                printf("??\n");
        } 

        printf("Error: %d\n", (int)stat);
        cudaFree (devPtrA);
        cudaFree (devPtrC);
        cublasDestroy(handle);
        return EXIT_FAILURE;
    }

    // get matrix c
    stat = cublasGetMatrix (N, N, sizeof(*c), devPtrC, N, c, N);
    if (stat != CUBLAS_STATUS_SUCCESS) {
        printf ("data upload failed");
        cudaFree (devPtrC);
        cublasDestroy(handle);
        return EXIT_FAILURE;
    }

    // cleanup cuda/cublas
    cudaFree (devPtrA);
    cudaFree (devPtrC);
    cublasDestroy(handle);

    // print result
    for (j = 0; j < N; j++) {
        for (i = 0; i < N; i++) {
            printf ("%7.0f", c[IDX2C(i,j,M)]);
        }
        printf ("\n");
    }

    // clear host data
    free(a);
    free(c);
    return EXIT_SUCCESS;
}

1 个答案:

答案 0 :(得分:3)

第一个问题是你是以行主格式填充矩阵A.要解决这个问题,只需交换i和j索引即可。在列主要格式中,前导维度应为行数,在您的情况下为N。

for (j = 0; j < N; j++) { 
    for (i = 0; i < M; i++) {
        a[IDX2C(j,i,N)] = (double)(i * M + j + 1);
        printf ("%7.0f", a[IDX2C(j,i,N)]);
    }
    printf ("\n");
}

您还在cublasDgemm调用中交换尺寸,它应该如下所示:

stat = cublasDgemm(handle, CUBLAS_OP_N, CUBLAS_OP_T, N, N, M, &alpha, devPtrA, N, devPtrA, N, &beta, devPtrC, N);

最后,你使用M作为C矩阵的主要维度,它应该是N:

printf ("%7.0f", c[IDX2C(i,j,N)]);