按奇数个元素一组(5个、7个等)对向量元素做水平求和的最快方法是什么?

时间:2017-03-25 23:28:32

标签: optimization x86 simd intrinsics avx2

参照 this question,我实现了水平求和,这次是每 5 个元素一组和每 7 个元素一组。它能正确地完成工作,但速度不够快。还能更快吗?我尝试过使用 hadd 和其他指令,但改进有限。例如,改用 _mm256_bsrli_epi128 时会稍微好一些,但由于 128 位通道(lane)的限制,它需要一些额外的置换(permutation),从而抵消了收益。所以问题是:应该如何实现才能获得更高的性能?对于 9 个元素等情况也是同样的问题。

下面的函数对每 5 个相邻元素做水平求和,并将结果放在第 0、5 和 10 位:

/*
 * Horizontal sum of groups of five adjacent 16-bit elements.
 *
 * On return, element 0 holds a[0]+...+a[4], element 5 holds a[5]+...+a[9]
 * and element 10 holds a[10]+...+a[14] (16-bit wrap-around addition).
 * The remaining elements contain partial sums and should be ignored.
 *
 * Declared `static inline`: a plain `inline` definition in C99/C11 does not
 * emit an external definition, so a call that the compiler chooses not to
 * inline (e.g. at -O0) would fail to link.
 */
static inline __m256i _mm256_hadd5x5_epi16(__m256i a)
{
    /* Low lane = upper half of a, high lane = zero.  Used as the "high"
     * operand of vpalignr, it turns the per-lane byte shift into a right
     * shift across the whole 256-bit register with zero fill. */
    const __m256i upper = _mm256_permute2x128_si256(a, _mm256_setzero_si256(), 0x31);

    const __m256i a1 = _mm256_alignr_epi8(upper, a, 1 * 2);   /* element i <- a[i+1] */
    const __m256i a2 = _mm256_alignr_epi8(upper, a, 2 * 2);   /* element i <- a[i+2] */

    /* The next two shifts stay inside each 128-bit lane; that is sufficient
     * for columns 0, 5 and 10, which then see a[i+3] and a[i+4]. */
    const __m256i a3 = _mm256_bsrli_epi128(a2, 2);
    const __m256i a4 = _mm256_bsrli_epi128(a3, 2);

    return _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(a1, a2), _mm256_add_epi16(a3, a4)), a);
}

下面的函数对每 7 个相邻元素做水平求和,并将结果放在第 0 和 7 位:

/*
 * Horizontal sum of groups of seven adjacent 16-bit elements.
 *
 * Every output element i holds a[i]+a[i+1]+...+a[i+6], where elements past
 * index 15 count as zero (16-bit wrap-around addition).  The two sums of
 * interest are therefore found in element 0 (a[0..6]) and element 7
 * (a[7..13]).
 *
 * Declared `static inline`: a plain `inline` definition in C99/C11 does not
 * emit an external definition, so a call that the compiler chooses not to
 * inline (e.g. at -O0) would fail to link.
 */
static inline __m256i _mm256_hadd7x7_epi16(__m256i a)
{
    /* Low lane = upper half of a, high lane = zero.  Used as the "high"
     * operand of vpalignr, it turns the per-lane byte shift into a right
     * shift across the whole 256-bit register with zero fill. */
    const __m256i upper = _mm256_permute2x128_si256(a, _mm256_setzero_si256(), 0x31);

    /* aK: element i <- a[i+K], zero beyond the end of the vector. */
    const __m256i a1 = _mm256_alignr_epi8(upper, a, 1 * 2);
    const __m256i a2 = _mm256_alignr_epi8(upper, a, 2 * 2);
    const __m256i a3 = _mm256_alignr_epi8(upper, a, 3 * 2);
    const __m256i a4 = _mm256_alignr_epi8(upper, a, 4 * 2);
    const __m256i a5 = _mm256_alignr_epi8(upper, a, 5 * 2);
    const __m256i a6 = _mm256_alignr_epi8(upper, a, 6 * 2);

    const __m256i s1234 = _mm256_add_epi16(_mm256_add_epi16(a1, a2), _mm256_add_epi16(a3, a4));
    const __m256i s560  = _mm256_add_epi16(_mm256_add_epi16(a5, a6), a);

    return _mm256_add_epi16(s1234, s560);
}

1 个答案:

答案 0 :(得分:2)

确实可以用更少的指令计算这些和。思路是:不仅在第 10、5、0 列中累加部分和,也在其他列中累加。与您的方案相比,这减少了 vpaddw 指令的数量以及“shuffle”(混洗)指令的数量。

#include <stdio.h>
#include <x86intrin.h>
/*  gcc -O3 -Wall -m64 -march=haswell hor_sum5x5.c   */
int print_vec_short(__m256i x);
int print_10_5_0_short(__m256i x);
__m256i _mm256_hadd5x5_epi16(__m256i a  );
__m256i _mm256_hadd7x7_epi16(__m256i a  );

/*
 * Demo driver.
 *
 * Part 1 computes the sums of 5 adjacent shorts (answers in columns 10, 5
 * and 0) with one cross-lane permute, two shifts and three additions, prints
 * every intermediate vector, and compares against _mm256_hadd5x5_epi16().
 *
 * Part 2 computes the sums of shorts 13...7 and 6...0 with in-lane rotations
 * and compares against _mm256_hadd7x7_epi16().
 */
int main(void) {
   short x[16];

   for(int i=0; i<16; i++) x[i] = i+1;   /* arbitrary initial values */

   __m256i t0     = _mm256_loadu_si256((__m256i*)x);
   /* Rotate right by one 32-bit element (= two shorts): t2[i] = t0[i+2]. */
   __m256i t2     = _mm256_permutevar8x32_epi32(t0,_mm256_set_epi32(0,7,6,5,4,3,2,1));
   __m256i t02    = _mm256_add_epi16(t0,t2);      /* t0[i] + t0[i+2]                      */
   /* In-lane shift by two more shorts: columns 10, 5, 0 now see t0[i+4]. */
   __m256i t3     = _mm256_bsrli_epi128(t2,4);    /* byte shift right */
   __m256i t023   = _mm256_add_epi16(t02,t3);     /* t0[i] + t0[i+2] + t0[i+4]            */
   /* Shift the pair sums by one short inside each 64-bit element:
      columns 10, 5, 0 now see t0[i+1] + t0[i+3]. */
   __m256i t13    = _mm256_srli_epi64(t02,16);    /* bit shift right  */
   __m256i sum    = _mm256_add_epi16(t023,t13);

   printf("t0  = ");print_vec_short(t0  );
   printf("t2  = ");print_vec_short(t2  );
   printf("t02 = ");print_vec_short(t02 );
   printf("t3  = ");print_vec_short(t3  );
   printf("t023= ");print_vec_short(t023);
   printf("t13 = ");print_vec_short(t13 );
   printf("sum = ");print_vec_short(sum );


   printf("\nVector elements of interest: columns  10, 5, 0:\n");
   printf("t0  [10, 5, 0]  = ");print_10_5_0_short(t0  );
   printf("t2  [10, 5, 0]  = ");print_10_5_0_short(t2  );
   printf("t02 [10, 5, 0]  = ");print_10_5_0_short(t02 );
   printf("t3  [10, 5, 0]  = ");print_10_5_0_short(t3  );
   printf("t023[10, 5, 0]  = ");print_10_5_0_short(t023);
   printf("t13 [10, 5, 0]  = ");print_10_5_0_short(t13 );
   printf("sum [10, 5, 0]  = ");print_10_5_0_short(sum );


   /* Reference result from the straightforward implementation. */
   printf("\nSum with _mm256_hadd5x5_epi16(t0)\n");
   sum = _mm256_hadd5x5_epi16(t0);
   printf("sum [10, 5, 0]  = ");print_10_5_0_short(sum );

   /* now the sum of 7 elements: */
   printf("\n\nSum of short ints 13...7 and short ints 6...0:\n");

   __m256i t      = _mm256_loadu_si256((__m256i*)x);
   /* Copy 32-bit element 3 (shorts 7 and 6) into element 7, so that short 7
      ends up in the upper lane at position 15. */
           t0     = _mm256_permutevar8x32_epi32(t,_mm256_set_epi32(3,6,5,4,3,2,1,0));
   /* Clear shorts 14 and 7: each lane then holds exactly the seven values
      that belong to its sum.  -1 is the all-ones short; 0xFFFF does not fit
      in the `short` parameters of _mm256_set_epi16. */
           t0     = _mm256_and_si256(t0,_mm256_set_epi16(-1,0,-1,-1,-1,-1,-1,-1,   0,-1,-1,-1,-1,-1,-1,-1));
   /* Rotations by 1, 2 and 4 shorts inside each lane: after the three
      additions every element contains the total of its lane. */
   __m256i t1     = _mm256_alignr_epi8(t0,t0,2);
   __m256i t01    = _mm256_add_epi16(t0,t1);
   __m256i t23    = _mm256_alignr_epi8(t01,t01,4);
   __m256i t0123  = _mm256_add_epi16(t01,t23);
   __m256i t4567  = _mm256_alignr_epi8(t0123,t0123,8);
   __m256i sum08  = _mm256_add_epi16(t0123,t4567);      /* all elements are summed, but another permutation is needed to get the answer at position 7 */
   /* 32-bit elements 7..3 <- upper-lane total, elements 2..0 <- lower-lane
      total; column 7 lives in element 3 and column 0 in element 0. */
           sum    = _mm256_permutevar8x32_epi32(sum08,_mm256_set_epi32(4,4,4,4,4,0,0,0));

   printf("t     = ");print_vec_short(t     );
   printf("t0    = ");print_vec_short(t0    );
   printf("t1    = ");print_vec_short(t1    );
   printf("t01   = ");print_vec_short(t01   );
   printf("t23   = ");print_vec_short(t23   );
   printf("t0123 = ");print_vec_short(t0123 );
   printf("t4567 = ");print_vec_short(t4567 );
   printf("sum08 = ");print_vec_short(sum08 );
   printf("sum   = ");print_vec_short(sum   );

   /* Reference result from the straightforward implementation. */
   printf("\nSum with _mm256_hadd7x7_epi16(t)     (the answer is in column 0 and in column 7)\n");
   sum = _mm256_hadd7x7_epi16(t);
   printf("sum   = ");print_vec_short(sum   );


   return 0;
}



/*
 * Reference implementation: horizontal sum of groups of five adjacent
 * 16-bit elements.
 *
 * On return, element 0 holds a[0]+...+a[4], element 5 holds a[5]+...+a[9]
 * and element 10 holds a[10]+...+a[14] (16-bit wrap-around addition).
 * The remaining elements contain partial sums and should be ignored.
 *
 * Defined without `inline`: a bare C99/C11 `inline` definition provides an
 * external definition only when some other declaration in the translation
 * unit omits `inline`; dropping the keyword keeps the definition linkable
 * on its own and still matches the prototype at the top of the file.
 */
__m256i _mm256_hadd5x5_epi16(__m256i a)
{
    /* Low lane = upper half of a, high lane = zero.  Used as the "high"
     * operand of vpalignr, it turns the per-lane byte shift into a right
     * shift across the whole 256-bit register with zero fill. */
    const __m256i upper = _mm256_permute2x128_si256(a, _mm256_setzero_si256(), 0x31);

    const __m256i a1 = _mm256_alignr_epi8(upper, a, 1 * 2);   /* element i <- a[i+1] */
    const __m256i a2 = _mm256_alignr_epi8(upper, a, 2 * 2);   /* element i <- a[i+2] */

    /* The next two shifts stay inside each 128-bit lane; that is sufficient
     * for columns 0, 5 and 10, which then see a[i+3] and a[i+4]. */
    const __m256i a3 = _mm256_bsrli_epi128(a2, 2);
    const __m256i a4 = _mm256_bsrli_epi128(a3, 2);

    return _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(a1, a2), _mm256_add_epi16(a3, a4)), a);
}


/*
 * Reference implementation: horizontal sum of groups of seven adjacent
 * 16-bit elements.
 *
 * Every output element i holds a[i]+a[i+1]+...+a[i+6], where elements past
 * index 15 count as zero (16-bit wrap-around addition).  The two sums of
 * interest are therefore found in element 0 (a[0..6]) and element 7
 * (a[7..13]).
 *
 * Defined without `inline`: a bare C99/C11 `inline` definition provides an
 * external definition only when some other declaration in the translation
 * unit omits `inline`; dropping the keyword keeps the definition linkable
 * on its own and still matches the prototype at the top of the file.
 */
__m256i _mm256_hadd7x7_epi16(__m256i a)
{
    /* Low lane = upper half of a, high lane = zero.  Used as the "high"
     * operand of vpalignr, it turns the per-lane byte shift into a right
     * shift across the whole 256-bit register with zero fill. */
    const __m256i upper = _mm256_permute2x128_si256(a, _mm256_setzero_si256(), 0x31);

    /* aK: element i <- a[i+K], zero beyond the end of the vector. */
    const __m256i a1 = _mm256_alignr_epi8(upper, a, 1 * 2);
    const __m256i a2 = _mm256_alignr_epi8(upper, a, 2 * 2);
    const __m256i a3 = _mm256_alignr_epi8(upper, a, 3 * 2);
    const __m256i a4 = _mm256_alignr_epi8(upper, a, 4 * 2);
    const __m256i a5 = _mm256_alignr_epi8(upper, a, 5 * 2);
    const __m256i a6 = _mm256_alignr_epi8(upper, a, 6 * 2);

    const __m256i s1234 = _mm256_add_epi16(_mm256_add_epi16(a1, a2), _mm256_add_epi16(a3, a4));
    const __m256i s560  = _mm256_add_epi16(_mm256_add_epi16(a5, a6), a);

    return _mm256_add_epi16(s1234, s560);
}


/*
 * Print the sixteen 16-bit elements of x on one line, element 15 first and
 * element 0 last, with a '|' after every group of four.  Always returns 0.
 */
int print_vec_short(__m256i x){
   short int elem[16];

   /* Spill the vector to memory so the elements can be indexed. */
   _mm256_storeu_si256((__m256i *)elem, x);

   const short int *hi = elem + 8;   /* elements 8..15 (upper 128-bit lane) */
   const short int *lo = elem;       /* elements 0..7  (lower 128-bit lane) */

   printf("%4hi %4hi %4hi %4hi | %4hi %4hi %4hi %4hi | %4hi %4hi %4hi %4hi  | %4hi %4hi %4hi %4hi \n",
          hi[7], hi[6], hi[5], hi[4],
          hi[3], hi[2], hi[1], hi[0],
          lo[7], lo[6], lo[5], lo[4],
          lo[3], lo[2], lo[1], lo[0]);
   return 0;
}

/*
 * Print only the 16-bit elements 10, 5 and 0 of x (in that order) on one
 * line.  Always returns 0.
 */
int print_10_5_0_short(__m256i x){
   short int elem[16];

   /* Spill the vector to memory so the elements can be indexed. */
   _mm256_storeu_si256((__m256i *)elem, x);

   const short int col10 = elem[10];
   const short int col5  = elem[5];
   const short int col0  = elem[0];

   printf("%4hi %4hi %4hi   \n", col10, col5, col0);
   return 0;
}

输出结果为:

$ ./a.out
t0  =   16   15   14   13 |   12   11   10    9 |    8    7    6    5  |    4    3    2    1 
t2  =    2    1   16   15 |   14   13   12   11 |   10    9    8    7  |    6    5    4    3 
t02 =   18   16   30   28 |   26   24   22   20 |   18   16   14   12  |   10    8    6    4 
t3  =    0    0    2    1 |   16   15   14   13 |    0    0   10    9  |    8    7    6    5 
t023=   18   16   32   29 |   42   39   36   33 |   18   16   24   21  |   18   15   12    9 
t13 =    0   18   16   30 |    0   26   24   22 |    0   18   16   14  |    0   10    8    6 
sum =   18   34   48   59 |   42   65   60   55 |   18   34   40   35  |   18   25   20   15 

Vector elements of interest: columns  10, 5, 0:
t0  [10, 5, 0]  =   11    6    1   
t2  [10, 5, 0]  =   13    8    3   
t02 [10, 5, 0]  =   24   14    4   
t3  [10, 5, 0]  =   15   10    5   
t023[10, 5, 0]  =   39   24    9   
t13 [10, 5, 0]  =   26   16    6   
sum [10, 5, 0]  =   65   40   15   

Sum with _mm256_hadd5x5_epi16(t0)
sum [10, 5, 0]  =   65   40   15   


Sum of short ints 13...7 and short ints 6...0:
t     =   16   15   14   13 |   12   11   10    9 |    8    7    6    5  |    4    3    2    1 
t0    =    8    0   14   13 |   12   11   10    9 |    0    7    6    5  |    4    3    2    1 
t1    =    9    8    0   14 |   13   12   11   10 |    1    0    7    6  |    5    4    3    2 
t01   =   17    8   14   27 |   25   23   21   19 |    1    7   13   11  |    9    7    5    3 
t23   =   21   19   17    8 |   14   27   25   23 |    5    3    1    7  |   13   11    9    7 
t0123 =   38   27   31   35 |   39   50   46   42 |    6   10   14   18  |   22   18   14   10 
t4567 =   39   50   46   42 |   38   27   31   35 |   22   18   14   10  |    6   10   14   18 
sum08 =   77   77   77   77 |   77   77   77   77 |   28   28   28   28  |   28   28   28   28 
sum   =   77   77   77   77 |   77   77   77   77 |   77   77   28   28  |   28   28   28   28 

Sum with _mm256_hadd7x7_epi16(t)     (the answer is in column 0 and in column 7)
sum   =   16   31   45   58 |   70   81   91   84 |   77   70   63   56  |   49   42   35   28