My code does the same thing in both versions, yet the AVX version is considerably slower than the SSE version. Can someone explain why?
What I have already done: I tried profiling the code with VerySleepy, but that gave me nothing useful; it only confirmed that the AVX version is slower...
I have also looked at the SSE/AVX guides and the instruction tables for my CPU (Haswell). The instructions have the same latency/throughput; only the horizontal add needs an additional instruction in the AVX version...
**Latency and throughput**
_mm_mul_ps -> L 5, T 0.5
_mm256_mul_ps -> L 5, T 0.5
_mm_hadd_ps -> L 5, T 2
_mm256_hadd_ps -> L 5, T ?
_mm256_extractf128_ps -> L 1, T 1
Summary of what the code does:
Final1 = SUM(m_Array1 * m_Array1 * m_Array3 * m_Array3)
Final2 = SUM(m_Array2 * m_Array2 * m_Array3 * m_Array3)
Final3 = SUM(m_Array1 * m_Array2 * m_Array3 * m_Array3)
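For reference, here is a minimal scalar sketch of what both SIMD loops are meant to compute, using the same names and the 32-element arrays from the initialization below (illustration only, not the code I actually run):

// Scalar reference for the three sums (illustration only).
for ( int k = 0; k < 32; ++k )
{
    float g1g3 = m_Array1[k] * m_Array3[k];
    float g2g3 = m_Array2[k] * m_Array3[k];
    Final1 += g1g3 * g1g3;   // m_Array1^2 * m_Array3^2
    Final2 += g2g3 * g2g3;   // m_Array2^2 * m_Array3^2
    Final3 += g1g3 * g2g3;   // m_Array1 * m_Array2 * m_Array3^2
}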
Initialization:
float Final1 = 0.0f;
float Final2 = 0.0f;
float Final3 = 0.0f;
float* m_Array1 = (float*)_mm_malloc( 32 * sizeof( float ), 32 );
float* m_Array2 = (float*)_mm_malloc( 32 * sizeof( float ), 32 );
float* m_Array3 = (float*)_mm_malloc( 32 * sizeof( float ), 32 );
SSE:
for ( int k = 0; k < 32; k += 4 )
{
    __m128 g1 = _mm_load_ps( m_Array1 + k );
    __m128 g2 = _mm_load_ps( m_Array2 + k );
    __m128 g3 = _mm_load_ps( m_Array3 + k );
    __m128 g1g3 = _mm_mul_ps( g1, g3 );
    __m128 g2g3 = _mm_mul_ps( g2, g3 );
    __m128 a1 = _mm_mul_ps( g1g3, g1g3 );
    __m128 a2 = _mm_mul_ps( g2g3, g2g3 );
    __m128 a3 = _mm_mul_ps( g1g3, g2g3 );
    // horizontal add
    {
        a1 = _mm_hadd_ps( a1, a1 );
        a1 = _mm_hadd_ps( a1, a1 );
        Final1 += _mm_cvtss_f32( a1 );
        a2 = _mm_hadd_ps( a2, a2 );
        a2 = _mm_hadd_ps( a2, a2 );
        Final2 += _mm_cvtss_f32( a2 );
        a3 = _mm_hadd_ps( a3, a3 );
        a3 = _mm_hadd_ps( a3, a3 );
        Final3 += _mm_cvtss_f32( a3 );
    }
}
AVX:
for ( int k = 0; k < 32; k += 8 )
{
    __m256 g1 = _mm256_load_ps( m_Array1 + k );
    __m256 g2 = _mm256_load_ps( m_Array2 + k );
    __m256 g3 = _mm256_load_ps( m_Array3 + k );
    __m256 g1g3 = _mm256_mul_ps( g1, g3 );
    __m256 g2g3 = _mm256_mul_ps( g2, g3 );
    __m256 a1 = _mm256_mul_ps( g1g3, g1g3 );
    __m256 a2 = _mm256_mul_ps( g2g3, g2g3 );
    __m256 a3 = _mm256_mul_ps( g1g3, g2g3 );
    // horizontal add1
    {
        __m256 t1 = _mm256_hadd_ps( a1, a1 );
        __m256 t2 = _mm256_hadd_ps( t1, t1 );
        __m128 t3 = _mm256_extractf128_ps( t2, 1 );
        __m128 t4 = _mm_add_ss( _mm256_castps256_ps128( t2 ), t3 );
        Final1 += _mm_cvtss_f32( t4 );
    }
    // horizontal add2
    {
        __m256 t1 = _mm256_hadd_ps( a2, a2 );
        __m256 t2 = _mm256_hadd_ps( t1, t1 );
        __m128 t3 = _mm256_extractf128_ps( t2, 1 );
        __m128 t4 = _mm_add_ss( _mm256_castps256_ps128( t2 ), t3 );
        Final2 += _mm_cvtss_f32( t4 );
    }
    // horizontal add3
    {
        __m256 t1 = _mm256_hadd_ps( a3, a3 );
        __m256 t2 = _mm256_hadd_ps( t1, t1 );
        __m128 t3 = _mm256_extractf128_ps( t2, 1 );
        __m128 t4 = _mm_add_ss( _mm256_castps256_ps128( t2 ), t3 );
        Final3 += _mm_cvtss_f32( t4 );
    }
}
Answer 0 (score: 5)
I took your code, put it into a test harness, compiled it with clang -O3, and timed it. I also implemented faster versions of both routines, with the horizontal add moved out of the loop:
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <sys/time.h> // gettimeofday
#include <immintrin.h>
static void sse(const float *m_Array1, const float *m_Array2, const float *m_Array3, size_t n, float *Final1, float *Final2, float *Final3)
{
    *Final1 = *Final2 = *Final3 = 0.0f;
    for (int k = 0; k < n; k += 4)
    {
        __m128 g1 = _mm_load_ps( m_Array1 + k );
        __m128 g2 = _mm_load_ps( m_Array2 + k );
        __m128 g3 = _mm_load_ps( m_Array3 + k );
        __m128 g1g3 = _mm_mul_ps( g1, g3 );
        __m128 g2g3 = _mm_mul_ps( g2, g3 );
        __m128 a1 = _mm_mul_ps( g1g3, g1g3 );
        __m128 a2 = _mm_mul_ps( g2g3, g2g3 );
        __m128 a3 = _mm_mul_ps( g1g3, g2g3 );
        // horizontal add
        {
            a1 = _mm_hadd_ps( a1, a1 );
            a1 = _mm_hadd_ps( a1, a1 );
            *Final1 += _mm_cvtss_f32( a1 );
            a2 = _mm_hadd_ps( a2, a2 );
            a2 = _mm_hadd_ps( a2, a2 );
            *Final2 += _mm_cvtss_f32( a2 );
            a3 = _mm_hadd_ps( a3, a3 );
            a3 = _mm_hadd_ps( a3, a3 );
            *Final3 += _mm_cvtss_f32( a3 );
        }
    }
}
static void sse_fast(const float *m_Array1, const float *m_Array2, const float *m_Array3, size_t n, float *Final1, float *Final2, float *Final3)
{
    *Final1 = *Final2 = *Final3 = 0.0f;
    __m128 a1 = _mm_setzero_ps();
    __m128 a2 = _mm_setzero_ps();
    __m128 a3 = _mm_setzero_ps();
    for (int k = 0; k < n; k += 4)
    {
        __m128 g1 = _mm_load_ps( m_Array1 + k );
        __m128 g2 = _mm_load_ps( m_Array2 + k );
        __m128 g3 = _mm_load_ps( m_Array3 + k );
        __m128 g1g3 = _mm_mul_ps( g1, g3 );
        __m128 g2g3 = _mm_mul_ps( g2, g3 );
        a1 = _mm_add_ps(a1, _mm_mul_ps( g1g3, g1g3 ));
        a2 = _mm_add_ps(a2, _mm_mul_ps( g2g3, g2g3 ));
        a3 = _mm_add_ps(a3, _mm_mul_ps( g1g3, g2g3 ));
    }
    // horizontal add
    a1 = _mm_hadd_ps( a1, a1 );
    a1 = _mm_hadd_ps( a1, a1 );
    *Final1 += _mm_cvtss_f32( a1 );
    a2 = _mm_hadd_ps( a2, a2 );
    a2 = _mm_hadd_ps( a2, a2 );
    *Final2 += _mm_cvtss_f32( a2 );
    a3 = _mm_hadd_ps( a3, a3 );
    a3 = _mm_hadd_ps( a3, a3 );
    *Final3 += _mm_cvtss_f32( a3 );
}
static void avx(const float *m_Array1, const float *m_Array2, const float *m_Array3, size_t n, float *Final1, float *Final2, float *Final3)
{
    *Final1 = *Final2 = *Final3 = 0.0f;
    for (int k = 0; k < n; k += 8 )
    {
        __m256 g1 = _mm256_load_ps( m_Array1 + k );
        __m256 g2 = _mm256_load_ps( m_Array2 + k );
        __m256 g3 = _mm256_load_ps( m_Array3 + k );
        __m256 g1g3 = _mm256_mul_ps( g1, g3 );
        __m256 g2g3 = _mm256_mul_ps( g2, g3 );
        __m256 a1 = _mm256_mul_ps( g1g3, g1g3 );
        __m256 a2 = _mm256_mul_ps( g2g3, g2g3 );
        __m256 a3 = _mm256_mul_ps( g1g3, g2g3 );
        // horizontal add1
        {
            __m256 t1 = _mm256_hadd_ps( a1, a1 );
            __m256 t2 = _mm256_hadd_ps( t1, t1 );
            __m128 t3 = _mm256_extractf128_ps( t2, 1 );
            __m128 t4 = _mm_add_ss( _mm256_castps256_ps128( t2 ), t3 );
            *Final1 += _mm_cvtss_f32( t4 );
        }
        // horizontal add2
        {
            __m256 t1 = _mm256_hadd_ps( a2, a2 );
            __m256 t2 = _mm256_hadd_ps( t1, t1 );
            __m128 t3 = _mm256_extractf128_ps( t2, 1 );
            __m128 t4 = _mm_add_ss( _mm256_castps256_ps128( t2 ), t3 );
            *Final2 += _mm_cvtss_f32( t4 );
        }
        // horizontal add3
        {
            __m256 t1 = _mm256_hadd_ps( a3, a3 );
            __m256 t2 = _mm256_hadd_ps( t1, t1 );
            __m128 t3 = _mm256_extractf128_ps( t2, 1 );
            __m128 t4 = _mm_add_ss( _mm256_castps256_ps128( t2 ), t3 );
            *Final3 += _mm_cvtss_f32( t4 );
        }
    }
}
static void avx_fast(const float *m_Array1, const float *m_Array2, const float *m_Array3, size_t n, float *Final1, float *Final2, float *Final3)
{
    *Final1 = *Final2 = *Final3 = 0.0f;
    __m256 a1 = _mm256_setzero_ps();
    __m256 a2 = _mm256_setzero_ps();
    __m256 a3 = _mm256_setzero_ps();
    for (int k = 0; k < n; k += 8 )
    {
        __m256 g1 = _mm256_load_ps( m_Array1 + k );
        __m256 g2 = _mm256_load_ps( m_Array2 + k );
        __m256 g3 = _mm256_load_ps( m_Array3 + k );
        __m256 g1g3 = _mm256_mul_ps( g1, g3 );
        __m256 g2g3 = _mm256_mul_ps( g2, g3 );
        a1 = _mm256_add_ps(a1, _mm256_mul_ps( g1g3, g1g3 ));
        a2 = _mm256_add_ps(a2, _mm256_mul_ps( g2g3, g2g3 ));
        a3 = _mm256_add_ps(a3, _mm256_mul_ps( g1g3, g2g3 ));
    }
    // horizontal add1
    {
        __m256 t1 = _mm256_hadd_ps( a1, a1 );
        __m256 t2 = _mm256_hadd_ps( t1, t1 );
        __m128 t3 = _mm256_extractf128_ps( t2, 1 );
        __m128 t4 = _mm_add_ss( _mm256_castps256_ps128( t2 ), t3 );
        *Final1 += _mm_cvtss_f32( t4 );
    }
    // horizontal add2
    {
        __m256 t1 = _mm256_hadd_ps( a2, a2 );
        __m256 t2 = _mm256_hadd_ps( t1, t1 );
        __m128 t3 = _mm256_extractf128_ps( t2, 1 );
        __m128 t4 = _mm_add_ss( _mm256_castps256_ps128( t2 ), t3 );
        *Final2 += _mm_cvtss_f32( t4 );
    }
    // horizontal add3
    {
        __m256 t1 = _mm256_hadd_ps( a3, a3 );
        __m256 t2 = _mm256_hadd_ps( t1, t1 );
        __m128 t3 = _mm256_extractf128_ps( t2, 1 );
        __m128 t4 = _mm_add_ss( _mm256_castps256_ps128( t2 ), t3 );
        *Final3 += _mm_cvtss_f32( t4 );
    }
}
int main(int argc, char *argv[])
{
    size_t n = 4096;

    if (argc > 1) n = atoi(argv[1]);

    float *in_1 = valloc(n * sizeof(in_1[0]));
    float *in_2 = valloc(n * sizeof(in_2[0]));
    float *in_3 = valloc(n * sizeof(in_3[0]));
    float out_1, out_2, out_3;
    struct timeval t0, t1;
    double t_ms;

    for (int i = 0; i < n; ++i)
    {
        in_1[i] = (float)rand() / (float)(RAND_MAX / 2);
        in_2[i] = (float)rand() / (float)(RAND_MAX / 2);
        in_3[i] = (float)rand() / (float)(RAND_MAX / 2);
    }

    sse(in_1, in_2, in_3, n, &out_1, &out_2, &out_3);
    printf("sse : %g, %g, %g\n", out_1, out_2, out_3);
    sse_fast(in_1, in_2, in_3, n, &out_1, &out_2, &out_3);
    printf("sse_fast: %g, %g, %g\n", out_1, out_2, out_3);
    avx(in_1, in_2, in_3, n, &out_1, &out_2, &out_3);
    printf("avx : %g, %g, %g\n", out_1, out_2, out_3);
    avx_fast(in_1, in_2, in_3, n, &out_1, &out_2, &out_3);
    printf("avx_fast: %g, %g, %g\n", out_1, out_2, out_3);

    gettimeofday(&t0, NULL);
    for (int k = 0; k < 100; ++k) sse(in_1, in_2, in_3, n, &out_1, &out_2, &out_3);
    gettimeofday(&t1, NULL);
    t_ms = ((double)(t1.tv_sec - t0.tv_sec) + (double)(t1.tv_usec - t0.tv_usec) * 1.0e-6) * 1.0e3;
    printf("sse : %g, %g, %g, %g ms\n", out_1, out_2, out_3, t_ms);

    gettimeofday(&t0, NULL);
    for (int k = 0; k < 100; ++k) sse_fast(in_1, in_2, in_3, n, &out_1, &out_2, &out_3);
    gettimeofday(&t1, NULL);
    t_ms = ((double)(t1.tv_sec - t0.tv_sec) + (double)(t1.tv_usec - t0.tv_usec) * 1.0e-6) * 1.0e3;
    printf("sse_fast: %g, %g, %g, %g ms\n", out_1, out_2, out_3, t_ms);

    gettimeofday(&t0, NULL);
    for (int k = 0; k < 100; ++k) avx(in_1, in_2, in_3, n, &out_1, &out_2, &out_3);
    gettimeofday(&t1, NULL);
    t_ms = ((double)(t1.tv_sec - t0.tv_sec) + (double)(t1.tv_usec - t0.tv_usec) * 1.0e-6) * 1.0e3;
    printf("avx : %g, %g, %g, %g ms\n", out_1, out_2, out_3, t_ms);

    gettimeofday(&t0, NULL);
    for (int k = 0; k < 100; ++k) avx_fast(in_1, in_2, in_3, n, &out_1, &out_2, &out_3);
    gettimeofday(&t1, NULL);
    t_ms = ((double)(t1.tv_sec - t0.tv_sec) + (double)(t1.tv_usec - t0.tv_usec) * 1.0e-6) * 1.0e3;
    printf("avx_fast: %g, %g, %g, %g ms\n", out_1, out_2, out_3, t_ms);

    return 0;
}
The results on my 2.6 GHz Haswell (MacBook Pro) are:
sse : 0.439 ms
sse_fast: 0.153 ms
avx : 0.309 ms
avx_fast: 0.085 ms
So the AVX version does indeed appear to be faster than the SSE version, for both the original and the optimized implementations. The optimized implementations are significantly faster than the originals, with an even bigger gain for the AVX case.
I can only guess that either your compiler is not generating very good code for AVX (or you forgot to enable compiler optimizations?), or there is a problem with your benchmarking methodology.
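One cheap way to rule out the first possibility is a compile-time check that AVX code generation is actually enabled for the file containing the intrinsics. A minimal sketch, relying on the fact that gcc/clang define __AVX__ under -mavx or -march=haswell and MSVC defines it under /arch:AVX:

// Fails the build if AVX code generation is not enabled for this file.
#ifndef __AVX__
#error "AVX not enabled: use -mavx / -march=haswell (gcc/clang) or /arch:AVX (MSVC)"
#endif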