openCL中矩阵乘法中的浮点数的求和减慢了内核时间

时间:2015-04-12 12:06:34

标签: opencl matrix-multiplication

我有两个OpenCL内核用于矩阵乘法。除了一些舍入错误外,两者都正常工作。

第一个是从文献中直截了当的。它使用三个行列有序浮点数来表示C = A * B.给定的矩阵是二次的,以确保错误设置的维度没有副误差。

内核1 - 带浮点数的矩阵乘法

kernel void matrixmult(
  global float* a, 
  global float* b, 
  global float* c, 
  const unsigned int rows, const unsigned int cols) 
{

  const unsigned int i = get_global_id(0);
  const unsigned int j = get_global_id(1);
  if ((i >= cols) || (j >= rows)) return;

  float sum = 0.0f;
  for (int k = 0; k < cols; k++) {
    sum += a[j*cols + k]*b[k*cols+i];
  }

  c[j*cols + i] = sum;
}

第二个OpenCL内核使用float4-arrays并调用点积。给出矩阵B作为原始矩阵的转置。 float4是JavaCL中定义的float - 数组的解释。所以我只是在主机上为我的两个内核创建相同的数组。

内核2 - 使用float4的矩阵乘法

kernel void matrixmult4(
  global const float4* a, 
  global const float4* bTransposed, 
  global float* c, 
  const unsigned int n) 
{

  const int rowsOut = get_global_size(0); 
  const int colsOut = get_global_size(1); 

  const unsigned int row = get_global_id(0);
  const unsigned int col = get_global_id(1);

  if ((col > colsOut) || (row > rowsOut)) return;

  const int indexA = row*n/4;
  const int indexB = col*n/4;

  float sum = 0.0f;

  for (int k = 0; k < n/4; k++) {
    sum += dot(a[indexA+k], bTransposed[indexB+k]);
  }

  c[row*colsOut + col] = sum;
}

我的问题:第一个内核运行速度比第二个内核快〜5倍(GTX460约70毫秒)(约350毫秒)。主要消费从总和线上升

    sum += dot(a[indexA+k], bTransposed[indexB+k]);

如果我使用+而不是+=,第二个内核会像第一个内核一样运行,但当然矩阵乘法是错误的。

是否需要同步添加sum?它在同一个内核实例中使用,而不是其他地方。

UPDATE 这是生成的SPIR(?)代码(Intel HD 5100)

来自内核的Sniplet sum += ...

%32 = load i32* %indexA, align 4, !tbaa !12
%33 = load i32* %k, align 4, !tbaa !12
%34 = add nsw i32 %32, %33
%35 = sext i32 %34 to i64
%36 = load <4 x float> addrspace(1)** %1, align 8, !tbaa !9
%37 = getelementptr inbounds <4 x float> addrspace(1)* %36, i64 %35
%38 = load <4 x float> addrspace(1)* %37, align 16, !tbaa !10
%39 = load i32* %indexB, align 4, !tbaa !12
%40 = load i32* %k, align 4, !tbaa !12
%41 = add nsw i32 %39, %40
%42 = sext i32 %41 to i64
%43 = load <4 x float> addrspace(1)** %2, align 8, !tbaa !9
%44 = getelementptr inbounds <4 x float> addrspace(1)* %43, i64 %42
%45 = load <4 x float> addrspace(1)* %44, align 16, !tbaa !10
%46 = call float @_Z3dotDv4_fS_(<4 x f+loat> %38, <4 x float> %45)
%47 = load float* %sum, align 4, !tbaa !13
%48 = fadd float %47, %46
store float %48, float* %sum, align 4, !tbaa !13
br label %49

简化sum =版本的SPIR代码仅排除以%47%48

开头的行

fadd会导致这样的开销吗?

0 个答案:

没有答案