基于简单cuda的矩阵乘法中的nan误差

时间:2014-04-11 04:44:29

标签: matrix cuda nan multiplication

I started to write some simple cuda code recently. Please see my code below. 

#include "matrix_multiplication.h"

void MatMul(const Matrix A, const Matrix B, Matrix C)
{
  Matrix d_A;
  d_A.width = A.width;
  d_A.height = A.height;
  size_t size = A.width*A.height*sizeof(float);
  cudaError_t err = cudaMalloc(&d_A.elements,size);
  printf("CUDA malloc A: %s\n",cudaGetErrorString(err));
  err = cudaMemcpy(d_A.elements,A.elements,size,cudaMemcpyHostToDevice);
  printf("Copy A to device: %s\n",cudaGetErrorString(err));

  Matrix d_B;
  d_B.width = B.width;
  d_B.height = B.height;
  size = B.width*B.height*sizeof(float);
  err = cudaMalloc(&d_B.elements,size);
  printf("CUDA malloc B: %s\n",cudaGetErrorString(err));
  err = cudaMemcpy(d_B.elements,B.elements,size,cudaMemcpyHostToDevice);
  printf("Copy B to device: %s\n",cudaGetErrorString(err));

  Matrix d_C;
  d_C.width = C.width;
  d_C.height = C.height;
  size = C.width*C.height*sizeof(float);
  err = cudaMalloc(&d_C.elements,size);
  printf("CUDA malloc C: %s\n",cudaGetErrorString(err));

  dim3 dimBlock(BLOCK_SIZE, BLOCK_SIZE);
  dim3 dimGrid((B.width + dimBlock.x -1)/dimBlock.x,(A.height +
        dimBlock.y -1)/dimBlock.y);
  MatMulKernel<<<dimGrid,dimBlock>>>(d_A,d_B,d_C);
  err = cudaThreadSynchronize();
  printf("Run kernel: %s\n",cudaGetErrorString(err));

  err =
    cudaMemcpy(C.elements,d_C.elements,size,cudaMemcpyDeviceToHost);
  printf("Copy C off of device: %s\n",cudaGetErrorString(err));

  cudaFree(d_A.elements);
  cudaFree(d_B.elements);
  cudaFree(d_C.elements);
}

__global__ void MatMulKernel(Matrix A, Matrix B, Matrix C)
{
  float Cvalue = 0.0;
  int row = blockIdx.y * blockDim.y + threadIdx.y;
  int col = blockIdx.x * blockDim.x + threadIdx.x;

  if(row > A.height || col > B.width) return;

  for(int e = 0; e < A.width; ++e)
  {
    Cvalue +=(A.elements[row*A.width + e]) * (B.elements[e*B.width +
        col]);
  }
  C.elements[row*C.width + col] = Cvalue;
}

结果如下所示(5乘5矩阵),有时在某些元素中包含纳米或无意义值。

2 1 2 0 0

nan 1 2 1 1

1 0 2 1 1

nan 0 2 1 0

3 1 4 2 1

我无法弄清楚原因。请帮帮我。非常感谢。

1 个答案:

答案 0 :(得分:2)

内核中的绑定检查应该是这样的:

if(row >= A.height || col >= B.width) return;

如果没有 = 符号,您将(宽度+高度 - 1)额外线程包含在计算中,这将执行超出范围的内存访问并导致未定义的行为。