This is a fairly basic C++ question about computing a matrix multiplication on the GPU. The code below is technically MSL (Metal Shading Language), but the syntax is almost identical.

Apple provides a matrix multiplication example that computes A^T * B. I'm looking for some help modifying it to simply compute A * B.

Each invocation of this shader operates on an 8x8 sector of C, and gid is the position of that sector in the grid. Here is the source:
// Note:
//
// (1) m is the number of rows in matrices A and C.
//
// (2) n is the number of columns in matrix A; number of rows in matrix B.
//
// (3) k is the number of columns in matrices B and C.
//
// (4) Matrix multiply computes C = A^T * B where A is an m x n matrix (so
//     that A^T is n x m) and B is n x k.
//
// (5) pbytes is the stride in bytes from one row of matrix A to another.
//     pbytes should be a multiple of 32, i.e. A is padded to be an
//     M x k matrix where M > m and M is a multiple of 8.
//
// (6) Similarly, qbytes is the stride in bytes from one row of B to
//     another, i.e. B is padded to be an n x K matrix where K > k and
//     K is a multiple of 8.
//
// (7) The output matrix C is the M x K matrix.
typedef struct
{
    ushort m, k, n, pbytes, qbytes;
} MetalMatrixDim;

kernel void MatrixMultiply(const device float*       A    [[ buffer(0) ]],
                           const device float*       B    [[ buffer(1) ]],
                           device float*             C    [[ buffer(2) ]],
                           constant MetalMatrixDim&  dims [[ buffer(3) ]],
                           ushort2                   gid  [[ thread_position_in_grid ]])
{
    ushort m = dims.m;
    ushort k = dims.k;
    ushort n = dims.n;

    ushort pbytes = dims.pbytes;
    ushort qbytes = dims.qbytes;

    // Multiply gid by 8 to get the absolute position of this thread's 8x8 tile in C
    ushort2 gidIn = ushort2(gid.x << 3, gid.y << 3);

    if (gidIn.x >= m || gidIn.y >= k) return;

    // a points at row 0 of A, columns gidIn.x .. gidIn.x+7;
    // b points at row 0 of B, columns gidIn.y .. gidIn.y+7.
    const device float4* a = (const device float4*)(A + gidIn.x);
    const device float4* b = (const device float4*)(B + gidIn.y);

    // c points at row gidIn.x, column gidIn.y of C (the top-left corner of the tile).
    C = (device float*)((device char*)C + gidIn.x*qbytes);

    device float4* c = (device float4*)(C + gidIn.y);

    const device float4* Bend = (const device float4*)((const device char*)B + qbytes*n);

    // s(2r) and s(2r+1) accumulate the 8 columns of row r of the output tile.
    float4 s0  = 0.0f, s1  = 0.0f, s2  = 0.0f, s3  = 0.0f;
    float4 s4  = 0.0f, s5  = 0.0f, s6  = 0.0f, s7  = 0.0f;
    float4 s8  = 0.0f, s9  = 0.0f, s10 = 0.0f, s11 = 0.0f;
    float4 s12 = 0.0f, s13 = 0.0f, s14 = 0.0f, s15 = 0.0f;

    // Each iteration reads 8 consecutive floats from one row of A and 8 from the
    // same row of B, performs a rank-1 update of the 8x8 tile, then steps both
    // pointers down one row. After n rows the tile holds this block of A^T * B.
    do
    {
        float4 aCurr0 = a[0];
        float4 aCurr1 = a[1];
        float4 bCurr0 = b[0];
        float4 bCurr1 = b[1];

        s0  += (aCurr0.x * bCurr0);
        s2  += (aCurr0.y * bCurr0);
        s4  += (aCurr0.z * bCurr0);
        s6  += (aCurr0.w * bCurr0);

        s1  += (aCurr0.x * bCurr1);
        s3  += (aCurr0.y * bCurr1);
        s5  += (aCurr0.z * bCurr1);
        s7  += (aCurr0.w * bCurr1);

        s8  += (aCurr1.x * bCurr0);
        s10 += (aCurr1.y * bCurr0);
        s12 += (aCurr1.z * bCurr0);
        s14 += (aCurr1.w * bCurr0);

        s9  += (aCurr1.x * bCurr1);
        s11 += (aCurr1.y * bCurr1);
        s13 += (aCurr1.z * bCurr1);
        s15 += (aCurr1.w * bCurr1);

        a = (device float4*)((device char*)a + pbytes);
        b = (device float4*)((device char*)b + qbytes);

    } while(b < Bend);

    // Write out the 8x8 tile, one row of C (two float4s) at a time.
    c[0] = s0;  c[1] = s1;  c = (device float4*)((device char*)c + qbytes);
    c[0] = s2;  c[1] = s3;  c = (device float4*)((device char*)c + qbytes);
    c[0] = s4;  c[1] = s5;  c = (device float4*)((device char*)c + qbytes);
    c[0] = s6;  c[1] = s7;  c = (device float4*)((device char*)c + qbytes);
    c[0] = s8;  c[1] = s9;  c = (device float4*)((device char*)c + qbytes);
    c[0] = s10; c[1] = s11; c = (device float4*)((device char*)c + qbytes);
    c[0] = s12; c[1] = s13; c = (device float4*)((device char*)c + qbytes);
    c[0] = s14; c[1] = s15;
}
I've spent quite a bit of time on this, but the best I've come up with is a naive solution that takes no account of memory latency. What I'd like instead is to modify Apple's code to eliminate the transpose of A while still letting the GPU optimize its memory reads and writes.

Can anyone help?
Edit: here is my (very) naive implementation. It runs roughly 100x slower than Apple's kernel:
int pbytes = (int)dims.pbytes;
int qbytes = (int)dims.qbytes;

for (int row = 0; row < 8; row++) {

    // Start of row (gidIn.y + row) of A, as a float index (qbytes/4 floats per padded row).
    int aStart = (gidIn.y + row) * pbytes / 4;

    for (int col = 0; col < 8; col++) {

        // C[gidIn.y + row][gidIn.x + col]
        int cIdx = (gidIn.y + row) * (qbytes / 4) + gidIn.x + col;

        // Column (gidIn.x + col) of B, walked with a stride of qbytes/4 floats.
        int bStart = gidIn.x + col;

        float sum = 0.0f;

        // Scalar dot product of one row of A with one column of B.
        for (int i = 0; i < (pbytes / 4); i++) {
            float prod = A[aStart + i] * B[bStart + (i * qbytes / 4)];
            sum += prod;
        }

        C[cIdx] = sum;
    }
}
The problem with this implementation is that it doesn't optimize memory access at all. Ideally you would read and write as much data as possible in each transaction, letting the compiler vectorize the operations.
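To make the goal concrete, the direction I'm imagining is something like the rough, untested sketch below. It keeps Apple's 8x8 tiling, float4 accumulators, and vectorized reads of B, but walks A by rows instead of by columns so no transpose is needed. I'm assuming m, n, k, pbytes, qbytes keep the meanings from the comments above (m rows of A and C, n columns of A = rows of B, k columns of B and C, with pbytes the row stride of A and qbytes the row stride of both B and C), and the kernel name is just a placeholder. The A reads are still scalar and strided (one float per tile row each iteration), which I suspect is exactly the memory-access problem I don't know how to solve:

kernel void MatrixMultiplyAB(const device float*       A    [[ buffer(0) ]],
                             const device float*       B    [[ buffer(1) ]],
                             device float*             C    [[ buffer(2) ]],
                             constant MetalMatrixDim&  dims [[ buffer(3) ]],
                             ushort2                   gid  [[ thread_position_in_grid ]])
{
    ushort m = dims.m;
    ushort k = dims.k;
    ushort n = dims.n;
    ushort pbytes = dims.pbytes;   // row stride of A in bytes
    ushort qbytes = dims.qbytes;   // row stride of B and C in bytes

    // Top-left corner of this thread's 8x8 tile of C; gidIn.x indexes rows, gidIn.y columns.
    ushort2 gidIn = ushort2(gid.x << 3, gid.y << 3);
    if (gidIn.x >= m || gidIn.y >= k) return;

    // Base of the first A row covered by the tile; the other 7 rows sit at
    // multiples of pbytes below it.
    const device char* aRow = (const device char*)A + gidIn.x * pbytes;

    // b walks down the rows of B, always reading columns gidIn.y .. gidIn.y+7.
    const device float4* b    = (const device float4*)(B + gidIn.y);
    const device float4* Bend = (const device float4*)((const device char*)B + qbytes * n);

    // s[2r] and s[2r+1] accumulate row r of the 8x8 tile.
    float4 s[16];
    for (ushort i = 0; i < 16; ++i) s[i] = float4(0.0f);

    ushort p = 0;
    do
    {
        float4 b0 = b[0];
        float4 b1 = b[1];

        for (ushort r = 0; r < 8; ++r)
        {
            // A[gidIn.x + r][p] -- a scalar, strided read (this is the slow part).
            float ar = ((const device float*)(aRow + r * pbytes))[p];
            s[2*r]     += ar * b0;   // C[gidIn.x + r][gidIn.y   .. gidIn.y+3]
            s[2*r + 1] += ar * b1;   // C[gidIn.x + r][gidIn.y+4 .. gidIn.y+7]
        }

        ++p;
        b = (const device float4*)((const device char*)b + qbytes);
    } while (b < Bend);

    // Write out the tile, one row of C per iteration.
    device float* cRow = (device float*)((device char*)C + gidIn.x * qbytes);
    device float4* c   = (device float4*)(cRow + gidIn.y);
    for (ushort r = 0; r < 8; ++r)
    {
        c[0] = s[2*r];
        c[1] = s[2*r + 1];
        c = (device float4*)((device char*)c + qbytes);
    }
}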
Answer (score 0)
The MetalPerformanceShaders framework has a built-in matrix multiplication kernel that you can encode onto your Metal command buffer. I would suggest using that rather than sinking a lot of time into this.
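Something along these lines is the general idea (Swift host code, written from memory, so treat it as a sketch and check the current MPSMatrix / MPSMatrixMultiplication documentation; the function name, buffers, and dimensions are placeholders for whatever your app already has):

import Metal
import MetalPerformanceShaders

// Encode C = A * B (row-major float32) onto an existing command buffer.
// bufferA holds an m x n matrix, bufferB an n x k matrix, bufferC an m x k result.
func encodeMatrixMultiply(device: MTLDevice,
                          commandBuffer: MTLCommandBuffer,
                          bufferA: MTLBuffer,
                          bufferB: MTLBuffer,
                          bufferC: MTLBuffer,
                          m: Int, n: Int, k: Int) {
    let rowBytesA = n * MemoryLayout<Float>.stride
    let rowBytesB = k * MemoryLayout<Float>.stride
    let rowBytesC = k * MemoryLayout<Float>.stride

    let descA = MPSMatrixDescriptor(rows: m, columns: n, rowBytes: rowBytesA, dataType: .float32)
    let descB = MPSMatrixDescriptor(rows: n, columns: k, rowBytes: rowBytesB, dataType: .float32)
    let descC = MPSMatrixDescriptor(rows: m, columns: k, rowBytes: rowBytesC, dataType: .float32)

    let matA = MPSMatrix(buffer: bufferA, descriptor: descA)
    let matB = MPSMatrix(buffer: bufferB, descriptor: descB)
    let matC = MPSMatrix(buffer: bufferC, descriptor: descC)

    // Computes C = alpha * op(A) * op(B) + beta * C; with both transposes off
    // this is the plain A * B you're after.
    let matMul = MPSMatrixMultiplication(device: device,
                                         transposeLeft: false,
                                         transposeRight: false,
                                         resultRows: m,
                                         resultColumns: k,
                                         interiorColumns: n,
                                         alpha: 1.0,
                                         beta: 0.0)

    matMul.encode(commandBuffer: commandBuffer,
                  leftMatrix: matA,
                  rightMatrix: matB,
                  resultMatrix: matC)
}

Note that MPS has alignment preferences for rowBytes (MPSMatrixDescriptor exposes a helper for computing a recommended row stride), so in practice you may want padded rows much like Apple's kernel uses pbytes and qbytes.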