假设A是M * N矩阵并存储在column-major中,我尝试使用此函数cublasSgemm_v2
,这是Cublas中的Matrix-Matrix乘法API
cublasSgemm_v2(handle,CUBLAS_OP_T,CUBLAS_OP_N,N,N,M,&al,A,N,A,M,&beta,A_result,N)
在调用此函数之前,我测试矩阵A并且它看起来不错,但它显示参数8是非法的,我不知道为什么。
所以我决定使用另一个API来计算A.tanspose * A cublas<t>syrk()
。返回的结果存储在矩阵的下部或上部,这意味着未引用矩阵的其余部分,以及如何编写内核以将元素复制到对称部分?
另一个问题是我的程序有时会崩溃(可能是三分之一的可能性)在代码的开头,如cudaMalloc或cbulascreate或其他地方,我只是在代码中间修改一些代码,并且之前运行了很多次,什么可能是因为这个?
谢谢
答案 0 :(得分:0)
您必须仔细阅读cublas gemm documentation。
有一种方法可以使用A' * A
直接计算cublas<T>gemm
,但这很棘手。
cublasSgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N, N, K, M, &alpha,
A, M, A, M, &beta, B, N);
这是一种小黑客 - A是以列主要顺序和A(MxN)
存储的维度K = N
矩阵。
因此,您将获得B = A' * A
。