我想计算CUDA中所有列的总和以及矩阵的所有行的总和。一种方法是使用BLAS的SGEMV
子程序,将矩阵乘以1的向量。
但是,这会导致矩阵的两次扫描,假设它比L1缓存大得多:一个用于行,另一个用于列。另外,我计划进一步修改其他运算符的代码,所以这就是我编写自己的内核的原因。
到目前为止,我的方法是将矩阵分解为大小为32 x 32
的子矩阵。每个线程块将这样的子矩阵加载到共享内存中,计算子矩阵的行和列的总和,并将它们原子地添加到适当的输出(下面的row
和col
)。这样,矩阵数据只需要从VRAM读取一次。
为简单起见,代码假定矩阵为n x n
,n % 32 == 0
且线程块为32 x 32
__global__ void sum_cols_and_rows(size_t n, const float* matrix, float* col, float* row)
{
__shared__ float sh[32][32];
size_t x = blockDim.x * blockIdx.x + threadIdx.x;
size_t y = blockDim.y * blockIdx.y + threadIdx.y;
float sum = matrix[x + n * y];
sh[threadIdx.x][threadIdx.y] = sum;
for(unsigned w = 16; w >= 1; w /= 2)
sum += __shfl_down(sum, w);
const size_t laneID = threadIdx.x & 0x1f; // 32-1
if(laneID == 0)
atomicAdd(row + y, sum);
__syncthreads();
sum = sh[threadIdx.y][threadIdx.x]; // swapped indexes
for(unsigned w = 16; w >= 1; w /= 2)
sum += __shfl_down(sum, w);
if(laneID == 0)
atomicAdd(col + blockDim.x * blockIdx.x + threadIdx.y, sum);
}
// launch :
sum_cols_and_rows<<<dim3(n/32, n/32), dim3(32, 32), 32*32*sizeof(float)>>>(n, matrix, col, row);
然而,表现相当令人失望。我在GTX 980上看到了大约20%的理论224GB / s内存带宽,即使在大型矩阵上,例如 16384x16384。
有没有办法让这种方法达到理论带宽限制?
答案 0 :(得分:1)
在您的解决方案中,矩阵的每个NxN块都由单独的NxN线程块处理。实际上,每个单独的线程都做很少的工作,因此开销在实际计算中占主导地位您可以通过让线程块处理多个矩阵块来改进它。
但是有一个更简单的解决方案,每个矩阵块只使用N个线程,其中一个线程对整个列进行求和。
实施与此类似:
__global__ void sum_cols_and_rows(size_t n, const float* matrix, float* col, float* row)
{
size_t laneID = threadIdx.x & 31;
size_t x = blockDim.x * blockIdx.x + threadIdx.x;
size_t y = N_ITERATIONS * blockIdx.y;
size_t idx = y * n + x;
float vertical = 0;
for(int i = 0; i < N_ITERATIONS; i++) {
float v = matrix[idx];
vertical += v;
for(unsigned w = 16; w >= 1; w /= 2)
v += __shfl_down(v, w);
if(laneID == 0)
atomicAdd(&row[y], v);
y++;
idx += n;
}
atomicAdd(&col[x], vertical);
}
此处的可调参数是每个线程组的warp数和每个矩阵块中的行数(N_ITERATIONS
)。较大的值可能会降低开销,但代价是并行性。
另一个尝试的想法是vectorized loading - 其中一个:
float2 v2 = reinterpret_cast<float2*>(matrix)[idx];
float v = v2.x + v2.y;
float4 v4 = reinterpret_cast<float4*>(matrix)[idx];
float v = (v4.x + v4.y) + (v4.z + v4.w);