我试图优化以下代码(两个数组的平方差的总和):
inline float Square(float value)
{
return value*value;
}
float SquaredDifferenceSum(const float * a, const float * b, size_t size)
{
float sum = 0;
for(size_t i = 0; i < size; ++i)
sum += Square(a[i] - b[i]);
return sum;
}
所以我使用CPU的SSE指令进行了优化:
inline void SquaredDifferenceSum(const float * a, const float * b, size_t i, __m128 & sum)
{
__m128 _a = _mm_loadu_ps(a + i);
__m128 _b = _mm_loadu_ps(b + i);
__m128 _d = _mm_sub_ps(_a, _b);
sum = _mm_add_ps(sum, _mm_mul_ps(_d, _d));
}
inline float ExtractSum(__m128 a)
{
float _a[4];
_mm_storeu_ps(_a, a);
return _a[0] + _a[1] + _a[2] + _a[3];
}
float SquaredDifferenceSum(const float * a, const float * b, size_t size)
{
size_t i = 0, alignedSize = size/4*4;
__m128 sums = _mm_setzero_ps();
for(; i < alignedSize; i += 4)
SquaredDifferenceSum(a, b, i, sums);
float sum = ExtractSum(sums);
for(; i < size; ++i)
sum += Square(a[i] - b[i]);
return sum;
}
如果数组的大小不是太大,此代码可以正常工作。 但是如果大小足够大,则基本函数及其优化版本给出的结果之间存在大的计算错误。 所以我有一个问题:SSE优化代码中的错误在哪里导致计算错误。
答案 0 :(得分:7)
错误来自有限精度浮点数。 每增加两个浮点数就会产生与它们之间的差异成比例的计算误差。 在你的标量算法版本中,结果总和比每个术语大得多(如果数组的大小当然足够大)。 因此,它会导致大计算错误的累积。
在SSE版本的算法中,实际上有四个结果累加的总和。相对于标量代码,这些总和与每个术语之间的差异较小,为四倍。 因此,这会导致较小的计算错误。
有两种方法可以解决此错误:
1)使用双精度浮点数来累加和。
2)使用Kahan求和算法(也称为补偿求和),与明显的方法相比,它显着减少了通过添加有限精度浮点数序列而获得的总数中的数值误差。
https://en.wikipedia.org/wiki/Kahan_summation_algorithm
使用Kahan求和算法,您的标量代码将如下所示:
inline void KahanSum(float value, float & sum, float & correction)
{
float term = value - correction;
float temp = sum + term;
correction = (temp - sum) - term;
sum = temp;
}
float SquaredDifferenceKahanSum(const float * a, const float * b, size_t size)
{
float sum = 0, correction = 0;
for(size_t i = 0; i < size; ++i)
KahanSum(Square(a[i] - b[i]), sum, correction);
return sum;
}
SSE优化代码如下所示:
inline void SquaredDifferenceKahanSum(const float * a, const float * b, size_t i,
__m128 & sum, __m128 & correction)
{
__m128 _a = _mm_loadu_ps(a + i);
__m128 _b = _mm_loadu_ps(b + i);
__m128 _d = _mm_sub_ps(_a, _b);
__m128 term = _mm_sub_ps(_mm_mul_ps(_d, _d), correction);
__m128 temp = _mm_add_ps(sum, term);
correction = _mm_sub_ps(_mm_sub_ps(temp, sum), term);
sum = temp;
}
float SquaredDifferenceKahanSum(const float * a, const float * b, size_t size)
{
size_t i = 0, alignedSize = size/4*4;
__m128 sums = _mm_setzero_ps(), corrections = _mm_setzero_ps();
for(; i < alignedSize; i += 4)
SquaredDifferenceKahanSum(a, b, i, sums, corrections);
float sum = ExtractSum(sums), correction = 0;
for(; i < size; ++i)
KahanSum(Square(a[i] - b[i]), sum, correction);
return sum;
}