我编写了一个struct
和一些包装“CUBLAS矩阵对象”的函数
struct
是:
#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#define uint unsigned int
typedef struct {
uint rows;
uint cols;
float* devPtrvals;
} matrix;
alloc函数创建矩阵struct:
matrix* matrix_alloc(uint rows, uint cols)
{
cudaError_t cudaStat;
matrix* w = malloc(sizeof(matrix));
w->rows = rows;
w->cols = cols;
cudaStat = cudaMalloc((void**)&w->devPtrvals, sizeof(float) * rows * cols);
if(cudaStat != cudaSuccess) {
fprintf(stderr, "device memory allocation failed\n");
return NULL;
}
return w;
};
自由功能:
uint matrix_free(matrix* w)
{
cudaFree(w->devPtrvals);
free(w);
return 1;
};
从浮点数组设置矩阵值的函数:
uint matrix_set_vals(matrix* w, float* vals)
{
cublasStatus_t stat;
stat = cublasSetMatrix(w->rows, w->cols, sizeof(float),
vals, w->rows, w->devPtrvals, w->rows);
if(stat != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr, "data upload failed\n");
return 0;
}
return 1;
};
编写通用点积函数时遇到问题,该函数涵盖了矩阵的转置。这就是我写的:
matrix* matrix_dot(cublasHandle_t handle, char transA, char transB,
float alpha, matrix* v, matrix* w, float beta)
{
matrix* x = matrix_alloc(transA == CUBLAS_OP_N ? v->rows : v->cols,
transB == CUBLAS_OP_N ? w->cols : w->rows);
//cublasStatus_t cublasSgemm(cublasHandle_t handle,
// cublasOperation_t transa, cublasOperation_t transb,
// int m, int n, int k,
// const float *alpha,
// const float *A, int lda,
// const float *B, int ldb,
// const float *beta,
// float *C, int ldc)
cublasSgemm(handle, transA, transB,
transA == CUBLAS_OP_N ? v->rows : v->cols,
transB == CUBLAS_OP_N ? w->cols : w->rows,
transA == CUBLAS_OP_N ? v->cols : v->rows,
&alpha, v->devPtrvals, v->rows, w->devPtrvals,
w->rows, &beta, x->devPtrvals, x->rows);
return x;
};
示例:
我想要一个矩阵A:
1 2 3
4 5 6
7 8 9
10 11 12
这意味着:
float* a = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
matrix* A = matrix_alloc(4, 3);
matrix_set_vals(A, a);
并将其乘以转置的B:
1 2 3
4 5 6
也:
float* b = {1, 2, 3, 4, 5, 6};
matrix* B = matrix_alloc(2, 3);
matrix_set_vals(B, b);
A * B ^ T = C的结果:
14 32
32 77
50 122
68 167
我正在使用点功能:
matrix* C = matrix_dot(handle, CUBLAS_OP_N, CUBLAS_OP_T, 1.0, A, B, 0.0);
使用此功能时,我得到:** On entry to SGEMM parameter number 10 had an illegal value
我做错了什么?
答案 0 :(得分:1)
您的代码中存在2个问题。
首先,您将矩阵存储在row-major中,但cublas假定矩阵应存储在col-major中。对于col-major矩阵A
,应使用以下数据进行初始化。
float* a = {1,4,7,10,2,5,8,11,3,6,9,12};
实际上你可能已经注意到col-major cublas_gemm()
也可以用来计算行主矩阵乘法。由于存储在row-major中的矩阵M
的数据布局与存储在col-major中的转置矩阵M^T
的数据布局完全相同,如果存储中没有填充字节。所以,如果你想做
C_row = A_row * B_row
你可以改用
C_col_trans = B_col_trans * A_col_trans
C_row
和C_col_trans
的基础存储布局完全相同,以及A
和B
。
第二个问题是关于领先维度。当存储中没有填充字节时,行主矩阵的ld等于列数,而col-marjor矩阵的ld等于行数。
另一个问题是,您可能必须使用cublasOperation_t tansA
而不是char transA
。