将SSE矩阵向量乘法码转换为AVX

时间:2015-11-21 16:44:54

标签: c++ sse simd avx avx2

我试图将我的SSE功能转换为AVX。该函数执行向量矩阵乘法,这是我的工作SSE代码:

void multiply_matrix_by_vector_SSE(float* m, float* v, float* result, unsigned const int vector_dims)
{
    size_t i, j;
    for (i = 0; i < vector_dims; ++i)
    {
        __m128 acc = _mm_setzero_ps();
        for (j = 0; j < vector_dims; j += 4)
        {
            __m128 vec = _mm_load_ps(&v[j]);
            __m128 mat = _mm_load_ps(&m[j + vector_dims * i]);
            //acc = _mm_add_ps(acc, _mm_mul_ps(mat, vec));
            acc = _mm_fmadd_ps(mat, vec, acc);
        }
        acc = _mm_hadd_ps(acc, acc);
        acc = _mm_hadd_ps(acc, acc);
        _mm_store_ss(&result[i], acc);
    }
}

以下是我为AVX提出的建议:

void multiply_matrix_by_vector_AVX(float* m, float* v, float* result, unsigned const int vector_dims)
{
    size_t i, j;

    for (i = 0; i < vector_dims; ++i)
    {
        __m256 acc = _mm256_setzero_ps();
        for (j = 0; j < vector_dims; j += 8)
        {
            __m256 vec = _mm256_load_ps(&v[j]);
            __m256 mat = _mm256_load_ps(&m[j + vector_dims * i]);
            acc = _mm256_fmadd_ps(mat, vec, acc);
        }
        acc = _mm256_hadd_ps(acc, acc);
        acc = _mm256_hadd_ps(acc, acc);
        acc = _mm256_hadd_ps(acc, acc);
        acc = _mm256_hadd_ps(acc, acc);

        _mm256_store_ps(&result[i], acc);
    }
}

然而,AVX代码崩溃(Access violation reading location 0xFFFFFFFFFFFFFFFF)。

有人可以帮助我让我的AVX功能正常工作吗?

PS:我在函数中传递的矩阵和向量的大小总是8的倍数。另外,传递给我的SSE函数的数组是16位对齐(__declspec(align(16))float* = generate_matrix(256);)和我通过的数组到我的AVX功能是32位对齐(__declspec(align(32))float* = generate_matrix(256););

1 个答案:

答案 0 :(得分:3)

不幸的是,使用水平添加不会轻易地扩展到256位,因为指令(和大多数其他)是“laned” - 它的作用就像两个haddps并行,一个在上半部分和一个在下半部分,没有混合,所以底部和上半部分不会相加在一起。

此外,它当然仍然不是打包结果,并且打包存储有一个对齐的存储写入一些未对齐的地址并且将失败(该错误有点奇怪,但无论如何)。

无论如何,让我们修正水平总和:(未经测试)

// this part still works
acc = _mm256_hadd_ps(acc, acc);
acc = _mm256_hadd_ps(acc, acc);
// this is new
__m128 acc1 = _mm256_extractf128_ps(acc, 0);
__m128 acc2 = _mm256_extractf128_ps(acc, 1);
acc1 = _mm_add_ss(acc1, acc2);
// do scalar store, obviously
_mm_store_ss(&result[i], acc1);

顺便说一下,内环需要10个独立的链(和10个累加器)才能最大化Haswell的吞吐量。