如何针对dgemm阻塞矩阵乘法优化此C程序?

时间:2018-02-27 22:33:45

标签: c optimization compiler-optimization matrix-multiplication

我有这个用于矩阵乘法的C程序,我已经读过并试过循环展开......但除此之外我不知道如何加速C程序。有人能指出我还能做些什么吗?

const char* dgemm_desc = "Simple blocked dgemm.";

#if !defined(BLOCK_SIZE)
#define BLOCK_SIZE 41
#endif

#define min(a,b) (((a)<(b))?(a):(b))

/* This routine performs a dgemm operation
 *  C := C + A * B
 * where A, B, and C are lda-by-lda matrices stored in column-major format. 
 * On exit, A and B maintain their input values. */
void square_dgemm (int lda, double* A, double* B, double* C)
{
  /* For each block-row of A */
  for (int i = 0; i < lda; i += BLOCK_SIZE)
    /* For each block-column of B */
    for (int j = 0; j < lda; j += BLOCK_SIZE)
      /* Accumulate block dgemms into block of C */
      for (int k = 0; k < lda; k += BLOCK_SIZE)
      {
        /* Correct block dimensions if block "goes off edge of" the matrix */
        int M = min (BLOCK_SIZE, lda-i);
        int N = min (BLOCK_SIZE, lda-j);
        int K = min (BLOCK_SIZE, lda-k);

        /* Perform individual block dgemm */
if( (M % BLOCK_SIZE <= 2) && (N % BLOCK_SIZE <= 2) && (K % BLOCK_SIZE <= 1 ))
        {
            do_block_fast(lda, M, N, K, A + i + k*lda, B+ k + j *lda, C + i + j*lda);
        }else
        {
            do_block(lda, M, N, K, A + i + k*lda, B + k + j*lda, C + i + j*lda);

        }
      }
}

主程序,调用子程序do_block_fast()和do_block()。

void do_block_fast (int lda, int M, int N, int K, double* A, double* B, double* C)
{
 static double a[BLOCK_SIZE*BLOCK_SIZE] __attribute__ ((aligned (16)));
 // make a local aligned copy of A's block
   for( int j = 0; j < K; j++ )
    for( int i = 0; i < M; i++ )
     a[i+j*BLOCK_SIZE] = A[i+j*lda];
   /* For each row i of A */
    for (int i = 0; i < M; ++i)
    /* For each column j of B */
         for (int j = 0; j < N; ++j)
         {
           /* Compute C(i,j) */
             double cij = C[i+j*lda];
             //for (int k = 0; k < K; ++k){
              //  cij += a[i+k*BLOCK_SIZE] * B[k+j*lda];
            // }

             for (int k = 0; k < K; k+= 41)
             {
                cij += a[i+k*BLOCK_SIZE] * B[k+j*lda];
                cij += a[i+(k+1)*BLOCK_SIZE] * B[(k+1)+j*lda];
                cij += a[i+(k+2)*BLOCK_SIZE] * B[(k+2)+j*lda];
                cij += a[i+(k+3)*BLOCK_SIZE] * B[(k+3)+j*lda];
                cij += a[i+(k+4)*BLOCK_SIZE] * B[(k+4)+j*lda];
                cij += a[i+(k+5)*BLOCK_SIZE] * B[(k+5)+j*lda];
                cij += a[i+(k+6)*BLOCK_SIZE] * B[(k+6)+j*lda];
                cij += a[i+(k+7)*BLOCK_SIZE] * B[(k+7)+j*lda];
                cij += a[i+(k+8)*BLOCK_SIZE] * B[(k+8)+j*lda];
                cij += a[i+(k+9)*BLOCK_SIZE] * B[(k+9)+j*lda];
                cij += a[i+(k+10)*BLOCK_SIZE] * B[(k+10)+j*lda];
                cij += a[i+(k+11)*BLOCK_SIZE] * B[(k+11)+j*lda];
                cij += a[i+(k+12)*BLOCK_SIZE] * B[(k+12)+j*lda];
                cij += a[i+(k+13)*BLOCK_SIZE] * B[(k+13)+j*lda];
                cij += a[i+(k+14)*BLOCK_SIZE] * B[(k+14)+j*lda];
                cij += a[i+(k+15)*BLOCK_SIZE] * B[(k+15)+j*lda];
                cij += a[i+(k+16)*BLOCK_SIZE] * B[(k+16)+j*lda];
                cij += a[i+(k+17)*BLOCK_SIZE] * B[(k+17)+j*lda];
                cij += a[i+(k+18)*BLOCK_SIZE] * B[(k+18)+j*lda];
                cij += a[i+(k+19)*BLOCK_SIZE] * B[(k+19)+j*lda];
                cij += a[i+(k+20)*BLOCK_SIZE] * B[(k+20)+j*lda];
                cij += a[i+(k+21)*BLOCK_SIZE] * B[(k+21)+j*lda];
                cij += a[i+(k+22)*BLOCK_SIZE] * B[(k+22)+j*lda];
                cij += a[i+(k+23)*BLOCK_SIZE] * B[(k+23)+j*lda];
                cij += a[i+(k+24)*BLOCK_SIZE] * B[(k+24)+j*lda];
                cij += a[i+(k+25)*BLOCK_SIZE] * B[(k+25)+j*lda];
                cij += a[i+(k+26)*BLOCK_SIZE] * B[(k+26)+j*lda];
                cij += a[i+(k+27)*BLOCK_SIZE] * B[(k+27)+j*lda];
                cij += a[i+(k+28)*BLOCK_SIZE] * B[(k+28)+j*lda];
                cij += a[i+(k+29)*BLOCK_SIZE] * B[(k+29)+j*lda];
                cij += a[i+(k+30)*BLOCK_SIZE] * B[(k+30)+j*lda];
                cij += a[i+(k+31)*BLOCK_SIZE] * B[(k+31)+j*lda];
                cij += a[i+(k+32)*BLOCK_SIZE] * B[(k+32)+j*lda];
                cij += a[i+(k+33)*BLOCK_SIZE] * B[(k+33)+j*lda];
                     ...
                     ...
               }
             C[i+j*lda] = cij;
          }
}

子程序do_block_fast()...基本上因为我知道K = 41的值,所以我刚刚将最里面的for循环展开了41。

static void do_block (int lda, int M, int N, int K, double* A, double* B, double* C)
{
  /* For each row i of A */
  for (int i = 0; i < M; ++i)
    /* For each column j of B */
    for (int j = 0; j < N; ++j)
    {
      /* Compute C(i,j) */
      double cij = C[i+j*lda];
      if(K % 8 == 0)
      for(int k = 0; k < K; k += 8){
        cij += A[i+k*lda] * B[k+j*lda];
        cij += A[i+(k+1)*lda] * B[(k+1)+j*lda];
        cij += A[i+(k+2)*lda] * B[(k+2)+j*lda];
        cij += A[i+(k+3)*lda] * B[(k+3)+j*lda];
        cij += A[i+(k+4)*lda] * B[(k+4)+j*lda];
        cij += A[i+(k+5)*lda] * B[(k+5)+j*lda];
        cij += A[i+(k+6)*lda] * B[(k+6)+j*lda];
        cij += A[i+(k+7)*lda] * B[(k+7)+j*lda];
              }
      else if(K % 7 == 0)
      for(int k = 0; k < K; k += 7){
        cij += A[i+k*lda] * B[k+j*lda];
        cij += A[i+(k+1)*lda] * B[(k+1)+j*lda];
        cij += A[i+(k+2)*lda] * B[(k+2)+j*lda];
        cij += A[i+(k+3)*lda] * B[(k+3)+j*lda];
        cij += A[i+(k+4)*lda] * B[(k+4)+j*lda];
        cij += A[i+(k+5)*lda] * B[(k+5)+j*lda];
        cij += A[i+(k+6)*lda] * B[(k+6)+j*lda];
              }
      else if(K % 6 == 0)
      for(int k = 0; k < K; k += 6){
        cij += A[i+k*lda] * B[k+j*lda];
        cij += A[i+(k+1)*lda] * B[(k+1)+j*lda];
        cij += A[i+(k+2)*lda] * B[(k+2)+j*lda];
        cij += A[i+(k+3)*lda] * B[(k+3)+j*lda];
        cij += A[i+(k+4)*lda] * B[(k+4)+j*lda];
        cij += A[i+(k+5)*lda] * B[(k+5)+j*lda];
              }
      ...

      else
      for (int k = 0; k < K; ++k)
        cij += A[i+k*lda] * B[k+j*lda];
      C[i+j*lda] = cij;
    }
}

子程序do_block()。由于我不知道K的值,我只使用了分支语句,以展开最里面的for循环。

最后这是程序的结果

Size: 31    Mflop/s:  2361.82   Percentage:  5.51
Size: 32    Mflop/s:   2941.3   Percentage:  6.86
Size: 96    Mflop/s:  2861.72   Percentage:  6.67
Size: 97    Mflop/s:  2657.44   Percentage:  6.19
Size: 127   Mflop/s:  3028.53   Percentage:  7.06
Size: 128   Mflop/s:  2925.26   Percentage:  6.82
Size: 129   Mflop/s:   2823.6   Percentage:  6.58
Size: 191   Mflop/s:  2893.04   Percentage:  6.74
Size: 192   Mflop/s:  2925.12   Percentage:  6.82
Size: 229   Mflop/s:  2834.69   Percentage:  6.61
Size: 255   Mflop/s:  2963.94   Percentage:  6.91
Size: 256   Mflop/s:  2712.55   Percentage:  6.32
Size: 257   Mflop/s:   2905.3   Percentage:  6.77
Size: 319   Mflop/s:  2878.31   Percentage:  6.71
Size: 320   Mflop/s:  2855.04   Percentage:  6.66
Size: 321   Mflop/s:  2835.52   Percentage:  6.61
Size: 417   Mflop/s:  2971.67   Percentage:  6.93
Size: 479   Mflop/s:  2876.16   Percentage:  6.70
Size: 480   Mflop/s:  2813.88   Percentage:  6.56
Size: 511   Mflop/s:  2717.75   Percentage:  6.34
Size: 512   Mflop/s:  1740.14   Percentage:  4.06
Size: 639   Mflop/s:  2614.61   Percentage:  6.09
Size: 640   Mflop/s:  2453.03   Percentage:  5.72
Size: 767   Mflop/s:  2494.76   Percentage:  5.82
Size: 768   Mflop/s:  2389.77   Percentage:  5.57
Size: 769   Mflop/s:   2779.2   Percentage:  6.48
Average percentage of Peak = 6.38822

0 个答案:

没有答案