我想使用cublas进行以下矩阵 - 矩阵乘法:
cA(M by K) * cB(K by N) => cAout(M by N)
我指定cA,其中K为主要索引,cB为N,为主要索引。根据cublas-4.0手册,我应该这样做:
HANDLE_ERROR(cublasSgemm(hdl, CUBLAS_OP_N, CUBLAS_OP_N, M, K, N, &alpha, cA, K, cB, N, &beta, cAout, N));
但它不起作用。相反,以下代码通过精简切换cA和cB来生成预期结果:
HANDLE_ERROR(cublasSgemm(hdl, CUBLAS_OP_N, CUBLAS_OP_N, N, K, M, &alpha, cB, N, cA, K, &beta, cAout, N));
我使用的cublas版本是4.1.28。函数参数是否有约定变化?谢谢!
答案 0 :(得分:3)
回想一下,CUBLAS使用列主要存储约定。假设这些矩阵不是某个较大矩阵的一部分,则cA的前导维数为M,cB的前导维数为K,cAout的前导维数为M.因此,您的SGEMM调用应该读取
HANDLE_ERROR(cublasSgemm(hdl, CUBLAS_OP_N, CUBLAS_OP_N, M, K, N, &alpha, cA, M, cB, K, &beta, cAout, M));