参考 this question,我实现了水平加法,这次是 5x5 和 7x7 的情形。它能得到正确的结果,但速度不够快。
还能更快吗?我尝试过使用 hadd 和其他指令,但改进有限。
例如,使用 _mm256_bsrli_epi128 时会稍微好一点,
但由于它只在各个 128 位通道(lane)内移位,还需要一些额外的排列(permute)指令,而这会抵消收益。所以问题是:怎样实现才能获得更高的性能?对 9 个元素等情形也是同样的问题。
下面的函数对相邻的 5 个元素做水平求和,并将结果放在第 0、5 和 10 位:
/*
 * Horizontal add of five consecutive signed/unsigned 16-bit elements
 * (wrap-around arithmetic, as with _mm256_add_epi16).
 *
 * a:       sixteen 16-bit elements, element 0 in the lowest bits.
 * returns: a vector whose elements 0, 5 and 10 hold a[0..4], a[5..9] and
 *          a[10..14] summed. Only those three elements are meant to be read.
 *
 * Declared `static inline`: a plain `inline` definition with no other
 * declaration provides no external definition in C99/C11, so a call that the
 * compiler chooses not to inline would fail to link.
 */
static inline __m256i _mm256_hadd5x5_epi16(__m256i a )
{
    /* { high 128 bits of a, zero }. Pairing this with `a` in alignr turns the
       per-lane alignr into a right shift across the whole 256-bit register. */
    const __m256i upper = _mm256_permute2x128_si256(a, _mm256_setzero_si256(), 0x31);

    const __m256i s1 = _mm256_alignr_epi8(upper, a, 1 * 2);   /* element k holds a[k+1] */
    const __m256i s2 = _mm256_alignr_epi8(upper, a, 2 * 2);   /* element k holds a[k+2] */

    /* Per-lane byte shifts are enough from here on: the zeros they shift in at
       the top of each 128-bit lane do not reach elements 0, 5 and 10. */
    const __m256i s3 = _mm256_bsrli_epi128(s2, 2);             /* a[k+3] */
    const __m256i s4 = _mm256_bsrli_epi128(s3, 2);             /* a[k+4] */

    const __m256i s12 = _mm256_add_epi16(s1, s2);
    const __m256i s34 = _mm256_add_epi16(s3, s4);
    return _mm256_add_epi16(_mm256_add_epi16(s12, s34), a);
}
下面的函数对相邻的 7 个元素做水平求和,并将结果放在第 0 和 7 位:
/*
 * Horizontal add of seven consecutive 16-bit elements (wrap-around
 * arithmetic, as with _mm256_add_epi16).
 *
 * a:       sixteen 16-bit elements, element 0 in the lowest bits.
 * returns: element k holds a[k] + a[k+1] + ... + a[k+6], where elements past
 *          index 15 count as zero. The results of interest are element 0
 *          (a[0..6]) and element 7 (a[7..13]).
 *
 * Declared `static inline`: a plain `inline` definition with no other
 * declaration provides no external definition in C99/C11, so a call that the
 * compiler chooses not to inline would fail to link.
 */
static inline __m256i _mm256_hadd7x7_epi16(__m256i a )
{
    /* { high 128 bits of a, zero }, computed once and shared by all six
       shifts. Pairing it with `a` in alignr gives a right shift across the
       whole 256-bit register with zeros shifted in at the top. */
    const __m256i upper = _mm256_permute2x128_si256(a, _mm256_setzero_si256(), 0x31);

    const __m256i s1 = _mm256_alignr_epi8(upper, a, 1 * 2);   /* a[k+1] */
    const __m256i s2 = _mm256_alignr_epi8(upper, a, 2 * 2);   /* a[k+2] */
    const __m256i s3 = _mm256_alignr_epi8(upper, a, 3 * 2);   /* a[k+3] */
    const __m256i s4 = _mm256_alignr_epi8(upper, a, 4 * 2);   /* a[k+4] */
    const __m256i s5 = _mm256_alignr_epi8(upper, a, 5 * 2);   /* a[k+5] */
    const __m256i s6 = _mm256_alignr_epi8(upper, a, 6 * 2);   /* a[k+6] */

    /* Balanced add tree to keep the dependency chain short. */
    const __m256i s1234 = _mm256_add_epi16(_mm256_add_epi16(s1, s2), _mm256_add_epi16(s3, s4));
    const __m256i s056  = _mm256_add_epi16(_mm256_add_epi16(s5, s6), a);
    return _mm256_add_epi16(s1234, s056);
}
答案 0 :(得分:2)
确实可以用更少的指令计算这些和。思路是:
不仅在第 10、5、0 列累加部分和,也在其他列同时累加部分和。
与您的方案相比,这减少了 vpaddw 指令的数量,
也减少了 shuffle(重排)指令的数量。
#include <stdio.h>
#include <x86intrin.h>
/* gcc -O3 -Wall -m64 -march=haswell hor_sum5x5.c */
/* Print all sixteen 16-bit elements of x on one line, element 15 first and
   element 0 last. Always returns 0. */
int print_vec_short(__m256i x);
/* Print only elements 10, 5 and 0 of x, in that order. Always returns 0. */
int print_10_5_0_short(__m256i x);
/* Reference 5-element horizontal add; main reads its results from elements
   10, 5 and 0.
   Note: these two prototypes carry no `inline`, so the `inline` definitions
   further down are ordinary external definitions under C99/C11 rules. */
__m256i _mm256_hadd5x5_epi16(__m256i a );
/* Reference 7-element horizontal add; the results of interest are in
   elements 0 and 7. */
__m256i _mm256_hadd7x7_epi16(__m256i a );
/*
 * Demo driver: computes the 5-element and 7-element horizontal sums with a
 * reduced number of shuffles/adds, prints every intermediate vector, and
 * prints the result of the reference functions _mm256_hadd5x5_epi16 and
 * _mm256_hadd7x7_epi16 for comparison.
 */
int main(void) {
    short x[16];
    for (int i = 0; i < 16; i++) {
        x[i] = (short)(i + 1);                    /* arbitrary initial values */
    }

    /* ---- Sum of 5 elements; results of interest in elements 10, 5, 0 ---- */
    __m256i t0   = _mm256_loadu_si256((const __m256i *)x);
    /* Rotate right by one 32-bit element across lanes: element k <- x[k+2]. */
    __m256i t2   = _mm256_permutevar8x32_epi32(t0, _mm256_set_epi32(0,7,6,5,4,3,2,1));
    __m256i t02  = _mm256_add_epi16(t0, t2);      /* x[k] + x[k+2]                        */
    __m256i t3   = _mm256_bsrli_epi128(t2, 4);    /* byte shift right, per lane: x[k+4]   */
    __m256i t023 = _mm256_add_epi16(t02, t3);     /* x[k] + x[k+2] + x[k+4]               */
    __m256i t13  = _mm256_srli_epi64(t02, 16);    /* bit shift right: x[k+1] + x[k+3]     */
    __m256i sum  = _mm256_add_epi16(t023, t13);

    printf("t0 = ");print_vec_short(t0 );
    printf("t2 = ");print_vec_short(t2 );
    printf("t02 = ");print_vec_short(t02 );
    printf("t3 = ");print_vec_short(t3 );
    printf("t023= ");print_vec_short(t023);
    printf("t13 = ");print_vec_short(t13 );
    printf("sum = ");print_vec_short(sum );

    printf("\nVector elements of interest: columns 10, 5, 0:\n");
    printf("t0 [10, 5, 0] = ");print_10_5_0_short(t0 );
    printf("t2 [10, 5, 0] = ");print_10_5_0_short(t2 );
    printf("t02 [10, 5, 0] = ");print_10_5_0_short(t02 );
    printf("t3 [10, 5, 0] = ");print_10_5_0_short(t3 );
    printf("t023[10, 5, 0] = ");print_10_5_0_short(t023);
    printf("t13 [10, 5, 0] = ");print_10_5_0_short(t13 );
    printf("sum [10, 5, 0] = ");print_10_5_0_short(sum );

    printf("\nSum with _mm256_hadd5x5_epi16(t0)\n");
    sum = _mm256_hadd5x5_epi16(t0);
    printf("sum [10, 5, 0] = ");print_10_5_0_short(sum );

    /* ---- Sum of 7 elements: x[0..6] and x[7..13] ---- */
    printf("\n\nSum of short ints 13...7 and short ints 6...0:\n");
    __m256i t = _mm256_loadu_si256((const __m256i *)x);
    /* Start from the freshly loaded input. Copy 32-bit element 3 (x[6], x[7])
       into the top 32-bit slot so that x[7] also lives in the upper lane. */
    t0 = _mm256_permutevar8x32_epi32(t, _mm256_set_epi32(3,6,5,4,3,2,1,0));
    /* Zero element 14 (the duplicate of x[6]) and element 7 (x[7] in the lower
       lane). The lower lane then holds x[0..6] and the upper lane x[7..13].
       -1 is the all-ones 16-bit value; it avoids converting the out-of-range
       int 0xFFFF to short. */
    t0 = _mm256_and_si256(t0, _mm256_set_epi16(-1, 0,-1,-1,-1,-1,-1,-1,
                                                0,-1,-1,-1,-1,-1,-1,-1));
    /* Rotate-and-add within each 128-bit lane by 1, 2 and 4 elements: after
       three steps every element holds the total of its lane. */
    __m256i t1    = _mm256_alignr_epi8(t0, t0, 2);
    __m256i t01   = _mm256_add_epi16(t0, t1);
    __m256i t23   = _mm256_alignr_epi8(t01, t01, 4);
    __m256i t0123 = _mm256_add_epi16(t01, t23);
    __m256i t4567 = _mm256_alignr_epi8(t0123, t0123, 8);
    __m256i sum08 = _mm256_add_epi16(t0123, t4567);   /* all elements are summed, but another permutation is needed to get the answer at position 7 */
    sum = _mm256_permutevar8x32_epi32(sum08, _mm256_set_epi32(4,4,4,4,4,0,0,0));

    printf("t = ");print_vec_short(t );
    printf("t0 = ");print_vec_short(t0 );
    printf("t1 = ");print_vec_short(t1 );
    printf("t01 = ");print_vec_short(t01 );
    printf("t23 = ");print_vec_short(t23 );
    printf("t0123 = ");print_vec_short(t0123 );
    printf("t4567 = ");print_vec_short(t4567 );
    printf("sum08 = ");print_vec_short(sum08 );
    printf("sum = ");print_vec_short(sum );

    printf("\nSum with _mm256_hadd7x7_epi16(t) (the answer is in column 0 and in column 7)\n");
    sum = _mm256_hadd7x7_epi16(t);
    printf("sum = ");print_vec_short(sum );
    return 0;
}
/*
 * Reference 5-element horizontal add on 16-bit elements.
 * Elements 0, 5 and 10 of the result hold a[0..4], a[5..9] and a[10..14]
 * summed (wrap-around 16-bit addition); the other elements are not meant to
 * be read.
 */
inline __m256i _mm256_hadd5x5_epi16(__m256i a )
{
    /* { high 128 bits of a, zero }: with alignr this yields a shift across
       the full 256-bit register instead of a per-lane one. */
    const __m256i zero           = _mm256_setzero_si256();
    const __m256i high_then_zero = _mm256_permute2x128_si256(a, zero, 0x31);

    const __m256i next1 = _mm256_alignr_epi8(high_then_zero, a, 2);   /* a[k+1] */
    const __m256i next2 = _mm256_alignr_epi8(high_then_zero, a, 4);   /* a[k+2] */
    /* Per-lane shifts: the zeros entering at the top of each lane stay clear
       of elements 0, 5 and 10. */
    const __m256i next3 = _mm256_bsrli_epi128(next2, 2);              /* a[k+3] */
    const __m256i next4 = _mm256_bsrli_epi128(next3, 2);              /* a[k+4] */

    __m256i acc = _mm256_add_epi16(next1, next2);
    acc = _mm256_add_epi16(acc, _mm256_add_epi16(next3, next4));
    return _mm256_add_epi16(acc, a);
}
/*
 * Reference 7-element horizontal add on 16-bit elements.
 * Element k of the result is a[k] + ... + a[k+6] (wrap-around 16-bit
 * addition, elements past index 15 count as zero). The results of interest
 * are element 0 (a[0..6]) and element 7 (a[7..13]).
 */
inline __m256i _mm256_hadd7x7_epi16(__m256i a )
{
    /* { high 128 bits of a, zero }: lets alignr shift across the full
       256-bit register, filling with zeros from the top. */
    const __m256i zero           = _mm256_setzero_si256();
    const __m256i high_then_zero = _mm256_permute2x128_si256(a, zero, 0x31);

    const __m256i next1 = _mm256_alignr_epi8(high_then_zero, a,  2);  /* a[k+1] */
    const __m256i next2 = _mm256_alignr_epi8(high_then_zero, a,  4);  /* a[k+2] */
    const __m256i next3 = _mm256_alignr_epi8(high_then_zero, a,  6);  /* a[k+3] */
    const __m256i next4 = _mm256_alignr_epi8(high_then_zero, a,  8);  /* a[k+4] */
    const __m256i next5 = _mm256_alignr_epi8(high_then_zero, a, 10);  /* a[k+5] */
    const __m256i next6 = _mm256_alignr_epi8(high_then_zero, a, 12);  /* a[k+6] */

    __m256i left  = _mm256_add_epi16(next1, next2);
    __m256i right = _mm256_add_epi16(next5, next6);
    left  = _mm256_add_epi16(left, _mm256_add_epi16(next3, next4));
    right = _mm256_add_epi16(right, a);
    return _mm256_add_epi16(left, right);
}
/*
 * Print the sixteen 16-bit elements of x on one line, highest element (15)
 * first and element 0 last, with a '|' between groups of four.
 * Always returns 0.
 */
int print_vec_short(__m256i x){
    short elem[16];

    /* Unaligned store: elem has no particular alignment guarantee. */
    _mm256_storeu_si256((__m256i *)elem, x);
    printf("%4hi %4hi %4hi %4hi | %4hi %4hi %4hi %4hi | %4hi %4hi %4hi %4hi | %4hi %4hi %4hi %4hi \n",
           elem[15], elem[14], elem[13], elem[12],
           elem[11], elem[10], elem[9],  elem[8],
           elem[7],  elem[6],  elem[5],  elem[4],
           elem[3],  elem[2],  elem[1],  elem[0]);
    return 0;
}
/*
 * Print only elements 10, 5 and 0 of x (in that order), i.e. the positions
 * where the 5-element horizontal sums are read. Always returns 0.
 */
int print_10_5_0_short(__m256i x){
    short elem[16];

    /* Unaligned store: elem has no particular alignment guarantee. */
    _mm256_storeu_si256((__m256i *)elem, x);
    printf("%4hi %4hi %4hi \n", elem[10], elem[5], elem[0]);
    return 0;
}
输出结果为:
$ ./a.out
t0 = 16 15 14 13 | 12 11 10 9 | 8 7 6 5 | 4 3 2 1
t2 = 2 1 16 15 | 14 13 12 11 | 10 9 8 7 | 6 5 4 3
t02 = 18 16 30 28 | 26 24 22 20 | 18 16 14 12 | 10 8 6 4
t3 = 0 0 2 1 | 16 15 14 13 | 0 0 10 9 | 8 7 6 5
t023= 18 16 32 29 | 42 39 36 33 | 18 16 24 21 | 18 15 12 9
t13 = 0 18 16 30 | 0 26 24 22 | 0 18 16 14 | 0 10 8 6
sum = 18 34 48 59 | 42 65 60 55 | 18 34 40 35 | 18 25 20 15
Vector elements of interest: columns 10, 5, 0:
t0 [10, 5, 0] = 11 6 1
t2 [10, 5, 0] = 13 8 3
t02 [10, 5, 0] = 24 14 4
t3 [10, 5, 0] = 15 10 5
t023[10, 5, 0] = 39 24 9
t13 [10, 5, 0] = 26 16 6
sum [10, 5, 0] = 65 40 15
Sum with _mm256_hadd5x5_epi16(t0)
sum [10, 5, 0] = 65 40 15
Sum of short ints 13...7 and short ints 6...0:
t = 16 15 14 13 | 12 11 10 9 | 8 7 6 5 | 4 3 2 1
t0 = 8 0 14 13 | 12 11 10 9 | 0 7 6 5 | 4 3 2 1
t1 = 9 8 0 14 | 13 12 11 10 | 1 0 7 6 | 5 4 3 2
t01 = 17 8 14 27 | 25 23 21 19 | 1 7 13 11 | 9 7 5 3
t23 = 21 19 17 8 | 14 27 25 23 | 5 3 1 7 | 13 11 9 7
t0123 = 38 27 31 35 | 39 50 46 42 | 6 10 14 18 | 22 18 14 10
t4567 = 39 50 46 42 | 38 27 31 35 | 22 18 14 10 | 6 10 14 18
sum08 = 77 77 77 77 | 77 77 77 77 | 28 28 28 28 | 28 28 28 28
sum = 77 77 77 77 | 77 77 77 77 | 77 77 28 28 | 28 28 28 28
Sum with _mm256_hadd7x7_epi16(t) (the answer is in column 0 and in column 7)
sum = 16 31 45 58 | 70 81 91 84 | 77 70 63 56 | 49 42 35 28