#pragma unroll语句可以用于清理编译器预先评估的代码吗?

时间:2014-09-26 14:49:41

标签: c++ optimization cuda dry compiler-optimization

在NVIDIA制作精良的reduction optimization documentation中,他们最终得到的warpReduce如下:

Template <unsigned int blockSize>
__device__ void warpReduce(volatile int* sdata, int tid) {
    if (blockSize >= 64) sdata[tid] += sdata[tid + 32]; 
    if (blockSize >= 32) sdata[tid] += sdata[tid + 16]; 
    if (blockSize >= 16) sdata[tid] += sdata[tid + 8]; 
    if (blockSize >= 8) sdata[tid] += sdata[tid + 4]; 
    if (blockSize >= 4) sdata[tid] += sdata[tid + 2]; 
    if (blockSize >= 2) sdata[tid] += sdata[tid + 1]; 
}

这会伤害我作为Don't Repeat Yourself的支持者的敏感性,为什么他们不能节省程序员时间并引入更像

的结构
Template <unsigned int blockSize>
__device__ void warpReduce(volatile int* sdata, int tid) {
    #pragma unroll(6)
    for (int i = 64; i>1; i>>=1){
        if (blockSize >= i) sdata[tid] += sdata[tid + (i/2)]; 
    }
}

这会编译成相同的代码吗?或者这会使编译器混淆太多而无法优化它?我知道case语句有类似的DRY问题,但我不知道有办法解决这个问题,因为它将运行时输入映射到编译器时间常量定义的模板内核,但我不知道如何让程序员需要很多次写出相同的东西是有帮助的。

更不用说,如果warp大小发生变化,那么它可能没有(但是谁知道)我引入的代码将更容易更改,甚至可以使用常量来定义warp大小,允许进行一次更改在这里以及依赖于warp大小的任何其他地方更改优化。

非常相关的子问题,我使用内置的warpSize阅读也是有问题的编译器优化原因。否则我会把它包含在上面的代码中。转到device_launch_parameters.h中的定义,它会调用内置设备查询__cudaGet_warpSize()。如果这不是问题,NVIDIA可以在该文档中提供warpReduce的完整,通用优化以供参考,但没有,我想知道为什么?

Tl; dr:我可以将其写为

Template <unsigned int blockSize>
__device__ void warpReduce(volatile int* sdata, int tid) {
    #pragma unroll(6)
    for (int i = warpSize * 2; i>1; i>>=1){
        if (blockSize >= i) sdata[tid] += sdata[tid + (i/2)]; 
    }
}

并享受与其文档相同的优化?

1 个答案:

答案 0 :(得分:2)

这是一个富有洞察力的观察(我认为)。根据我的测试代码:

#include <stdio.h>

template <unsigned int blockSize>
__device__ void warpReduce1(volatile int* sdata, int tid) {
    if (blockSize >= 64) sdata[tid] += sdata[tid + 32];
    if (blockSize >= 32) sdata[tid] += sdata[tid + 16];
    if (blockSize >= 16) sdata[tid] += sdata[tid + 8];
    if (blockSize >= 8) sdata[tid] += sdata[tid + 4];
    if (blockSize >= 4) sdata[tid] += sdata[tid + 2];
    if (blockSize >= 2) sdata[tid] += sdata[tid + 1];
}

template <unsigned int blockSize>
__device__ void warpReduce2(volatile int* sdata, int tid) {
    #pragma unroll 6
    for (int i = 64; i>1; i>>=1){
        if (blockSize >= i) sdata[tid] += sdata[tid + (i/2)];
    }
}

template <unsigned int blockSize>
__global__ void reduce6(int *g_idata, int *g_odata, unsigned int n) {
  extern __shared__ int sdata[];
  unsigned int tid = threadIdx.x;
  unsigned int i = blockIdx.x*(blockSize*2) + tid;
  unsigned int gridSize = blockSize*2*gridDim.x;
  sdata[tid] = 0;
  while (i < n){sdata[tid] += g_idata[i] + g_idata[i+blockSize]; i += gridSize; }
  __syncthreads();
  if (blockSize >= 512) { if (tid < 256) { sdata[tid] += sdata[tid + 256]; } __syncthreads(); }
  if (blockSize >= 256) { if (tid < 128) { sdata[tid] += sdata[tid + 128]; } __syncthreads(); }
  if (blockSize >= 128) { if (tid < 64) { sdata[tid] += sdata[tid + 64]; } __syncthreads(); }
  if (tid < 32) warpReduce1<blockSize>(sdata, tid);
  if (tid == 0) g_odata[blockIdx.x] = sdata[0];
}

#define DSIZE 1048576
#define BSIZE 256
#define NBLKS 64

int main(){
  int *h_data, *d_idata, *d_odata;
  h_data=(int *)malloc(DSIZE * sizeof(int));
  cudaMalloc(&d_idata, DSIZE * sizeof(int));
  cudaMalloc(&d_odata, NBLKS * sizeof(int));
  for (int i = 0; i < DSIZE; i++) h_data[i] = rand()%2;
  cudaMemcpy(d_idata, h_data, DSIZE*sizeof(int), cudaMemcpyHostToDevice);
  reduce6<BSIZE><<<64, BSIZE>>>(d_idata, d_odata, DSIZE);
  cudaMemcpy(h_data, d_odata, NBLKS*sizeof(int), cudaMemcpyDeviceToHost);
  return 0;
}

生成的机器代码似乎相同:

Fatbin elf code:
================
arch = sm_20
code version = [1,7]
producer = <unknown>
host = linux
compile_size = 64bit
identifier = t576.cu

        code for sm_20

Fatbin elf code:
================
arch = sm_20
code version = [1,7]
producer = cuda
host = linux
compile_size = 64bit
identifier = t576.cu

        code for sm_20
                Function : _Z7reduce6ILj256EEvPiS0_j
        .headerflags    @"EF_CUDA_SM20 EF_CUDA_PTX_SM(EF_CUDA_SM20)"
        /*0000*/         MOV R1, c[0x1][0x100];                            /* 0x2800440400005de4 */
        /*0008*/         S2R R0, SR_CTAID.X;                               /* 0x2c00000094001c04 */
        /*0010*/         S2R R2, SR_TID.X;                                 /* 0x2c00000084009c04 */
        /*0018*/         MOV R4, c[0x0][0x14];                             /* 0x2800400050011de4 */
        /*0020*/         MOV R10, RZ;                                      /* 0x28000000fc029de4 */
        /*0028*/         ISCADD R8, R0, R2, 0x9;                           /* 0x4000000008021d23 */
        /*0030*/         SHL.W R3, R2, 0x2;                                /* 0x6000c0000820de03 */
        /*0038*/         SHL R9, R4, 0x9;                                  /* 0x6000c00024425c03 */
        /*0040*/         ISETP.GE.U32.AND P0, PT, R8, c[0x0][0x30], PT;    /* 0x1b0e4000c081dc03 */
        /*0048*/         STS [R3], RZ;                                     /* 0xc9000000003fdc85 */
        /*0050*/         SSY 0xf0;                                         /* 0x6000000260000007 */
        /*0058*/     @P0 NOP.S;                                            /* 0x40000000000001f4 */
        /*0060*/         MOV32I R11, 0x4;                                  /* 0x180000001002dde2 */
        /*0068*/         IMAD.U32.U32 RZ, R1, RZ, RZ;                      /* 0x207e0000fc1fdc03 */
        /*0070*/         SSY 0xe8;                                         /* 0x60000001c0000007 */
        /*0078*/         NOP;                                              /* 0x4000000000001de4 */
        /*0080*/         IMAD.U32.U32 R6.CC, R8, R11, c[0x0][0x20];        /* 0x2017800080819c03 */
        /*0088*/         IADD R5, R8, 0x100;                               /* 0x4800c00400815c03 */
        /*0090*/         IMAD.U32.U32.HI.X R7, R8, R11, c[0x0][0x24];      /* 0x209680009081dc43 */
        /*0098*/         IMAD.U32.U32 R4.CC, R5, R11, c[0x0][0x20];        /* 0x2017800080511c03 */
        /*00a0*/         IADD R8, R8, R9;                                  /* 0x4800000024821c03 */
        /*00a8*/         LD.E R7, [R6];                                    /* 0x840000000061dc85 */
        /*00b0*/         IMAD.U32.U32.HI.X R5, R5, R11, c[0x0][0x24];      /* 0x2096800090515c43 */
        /*00b8*/         ISETP.LT.U32.AND P0, PT, R8, c[0x0][0x30], PT;    /* 0x188e4000c081dc03 */
        /*00c0*/         LD.E R4, [R4];                                    /* 0x8400000000411c85 */
        /*00c8*/         IADD R5, R7, R10;                                 /* 0x4800000028715c03 */
        /*00d0*/         IADD R10, R5, R4;                                 /* 0x4800000010529c03 */
        /*00d8*/     @P0 BRA 0x80;                                         /* 0x4003fffe800001e7 */
        /*00e0*/         NOP.S;                                            /* 0x4000000000001df4 */
        /*00e8*/         STS.S [R3], R10;                                  /* 0xc900000000329c95 */
        /*00f0*/         IMAD.U32.U32 RZ, R1, RZ, RZ;                      /* 0x207e0000fc1fdc03 */
        /*00f8*/         BAR.RED.POPC RZ, RZ, RZ, PT;                      /* 0x50ee0000ffffdc04 */
        /*0100*/         ISETP.GT.U32.AND P0, PT, R2, 0x7f, PT;            /* 0x1a0ec001fc21dc03 */
        /*0108*/    @!P0 LDS R5, [R3];                                     /* 0xc100000000316085 */
        /*0110*/    @!P0 LDS R4, [R3+0x200];                               /* 0xc100000800312085 */
        /*0118*/    @!P0 IADD R4, R5, R4;                                  /* 0x4800000010512003 */
        /*0120*/    @!P0 STS [R3], R4;                                     /* 0xc900000000312085 */
        /*0128*/         BAR.RED.POPC RZ, RZ, RZ, PT;                      /* 0x50ee0000ffffdc04 */
        /*0130*/         ISETP.GT.U32.AND P0, PT, R2, 0x3f, PT;            /* 0x1a0ec000fc21dc03 */
        /*0138*/    @!P0 LDS R5, [R3];                                     /* 0xc100000000316085 */
        /*0140*/    @!P0 LDS R4, [R3+0x100];                               /* 0xc100000400312085 */
        /*0148*/    @!P0 IADD R4, R5, R4;                                  /* 0x4800000010512003 */
        /*0150*/    @!P0 STS [R3], R4;                                     /* 0xc900000000312085 */
        /*0158*/         SSY 0x240;                                        /* 0x6000000380000007 */
        /*0160*/         BAR.RED.POPC RZ, RZ, RZ, PT;                      /* 0x50ee0000ffffdc04 */
        /*0168*/         ISETP.GT.U32.AND P0, PT, R2, 0x1f, PT;            /* 0x1a0ec0007c21dc03 */
        /*0170*/     @P0 NOP.S;                                            /* 0x40000000000001f4 */
        /*0178*/         SHL.W R3, R2, 0x2;                                /* 0x6000c0000820de03 */
        /*0180*/         LDS R5, [R3];                                     /* 0xc100000000315c85 */
        /*0188*/         LDS R4, [R3+0x80];                                /* 0xc100000200311c85 */
        /*0190*/         IADD R6, R5, R4;                                  /* 0x4800000010519c03 */
        /*0198*/         STS [R3], R6;                                     /* 0xc900000000319c85 */
        /*01a0*/         LDS R5, [R3];                                     /* 0xc100000000315c85 */
        /*01a8*/         LDS R4, [R3+0x40];                                /* 0xc100000100311c85 */
        /*01b0*/         IADD R7, R5, R4;                                  /* 0x480000001051dc03 */
        /*01b8*/         STS [R3], R7;                                     /* 0xc90000000031dc85 */
        /*01c0*/         LDS R5, [R3];                                     /* 0xc100000000315c85 */
        /*01c8*/         LDS R4, [R3+0x20];                                /* 0xc100000080311c85 */
        /*01d0*/         IADD R6, R5, R4;                                  /* 0x4800000010519c03 */
        /*01d8*/         STS [R3], R6;                                     /* 0xc900000000319c85 */
        /*01e0*/         LDS R5, [R3];                                     /* 0xc100000000315c85 */
        /*01e8*/         LDS R4, [R3+0x10];                                /* 0xc100000040311c85 */
        /*01f0*/         IADD R7, R5, R4;                                  /* 0x480000001051dc03 */
        /*01f8*/         STS [R3], R7;                                     /* 0xc90000000031dc85 */
        /*0200*/         LDS R5, [R3];                                     /* 0xc100000000315c85 */
        /*0208*/         LDS R4, [R3+0x8];                                 /* 0xc100000020311c85 */
        /*0210*/         IADD R6, R5, R4;                                  /* 0x4800000010519c03 */
        /*0218*/         STS [R3], R6;                                     /* 0xc900000000319c85 */
        /*0220*/         LDS R5, [R3];                                     /* 0xc100000000315c85 */
        /*0228*/         LDS R4, [R3+0x4];                                 /* 0xc100000010311c85 */
        /*0230*/         IADD R4, R5, R4;                                  /* 0x4800000010511c03 */
        /*0238*/         STS.S [R3], R4;                                   /* 0xc900000000311c95 */
        /*0240*/         ISETP.NE.AND P0, PT, R2, RZ, PT;                  /* 0x1a8e0000fc21dc23 */
        /*0248*/     @P0 BRA.U 0x278;                                      /* 0x40000000a00081e7 */
        /*0250*/    @!P0 MOV32I R3, 0x4;                                   /* 0x180000001000e1e2 */
        /*0258*/    @!P0 LDS R2, [RZ];                                     /* 0xc100000003f0a085 */
        /*0260*/    @!P0 IMAD.U32.U32 R4.CC, R0, R3, c[0x0][0x28];         /* 0x20078000a0012003 */
        /*0268*/    @!P0 IMAD.U32.U32.HI.X R5, R0, R3, c[0x0][0x2c];       /* 0x20868000b0016043 */
        /*0270*/    @!P0 ST.E [R4], R2;                                    /* 0x940000000040a085 */
        /*0278*/         EXIT;                                             /* 0x8000000000001de7 */
                ..........................................



Fatbin ptx code:
================
arch = sm_20
code version = [4,1]
producer = cuda
host = linux
compile_size = 64bit
compressed
identifier = t576.cu

我是否使用:

  if (tid < 32) warpReduce1<blockSize>(sdata, tid);

或:

  if (tid < 32) warpReduce2<blockSize>(sdata, tid);

回想起来,编译器可以轻松确定for循环的行程计数:

    for (int i = 64; i>1; i>>=1){

展开的循环将生成与内联代码完全相同的序列。因此编译器生成相同的代码。