如何从AVX寄存器中获取数据?

时间:2016-06-03 10:51:49

标签: c++ visual-c++ avx fma

使用MSVC 2013和AVX 1,我在寄存器中有8个浮点数:

__m256 foo = mm256_fmadd_ps(a,b,c);

现在我想为所有8个花车调用inline void print(float) {...}。看起来像英特尔 AVX的内在因素会让这个变得相当复杂:

print(_castu32_f32(_mm256_extract_epi32(foo, 0)));
print(_castu32_f32(_mm256_extract_epi32(foo, 1)));
print(_castu32_f32(_mm256_extract_epi32(foo, 2)));
// ...

但MSVC甚至没有这两种内在函数。当然,我可以将值写回内存并从那里加载,但我怀疑在汇编级别没有必要溢出寄存器。

奖金问:我当然喜欢写

for(int i = 0; i !=8; ++i) 
    print(_castu32_f32(_mm256_extract_epi32(foo, i)))

但MSVC并不了解许多内在函数需要循环展开。如何在__m256 foo

中的8x32浮点数上写一个循环

4 个答案:

答案 0 :(得分:4)

假设你只有AVX(即没有AVX2)那么你可以这样做:

float extract_float(const __m128 v, const int i)
{
    float x;
    _MM_EXTRACT_FLOAT(x, v, i);
    return x;
}

void print(const __m128 v)
{
    print(extract_float(v, 0));
    print(extract_float(v, 1));
    print(extract_float(v, 2));
    print(extract_float(v, 3));
}

void print(const __m256 v)
{
    print(_mm256_extractf128_ps(v, 0));
    print(_mm256_extractf128_ps(v, 1));
}

但是我想我可能只会使用一个联盟:

union U256f {
    __m256 v;
    float a[8];
};

void print(const __m256 v)
{
    const U256f u = { v };

    for (int i = 0; i < 8; ++i)
        print(u.a[i]);
}

答案 1 :(得分:3)

小心:_mm256_fmadd_ps不属于AVX1。 FMA3有自己的功能位,仅在英特尔与Haswell上推出。 AMD推出带有Piledriver的FMA3(AVX1 + FMA4 + FMA3,无AVX2)。

在asm级别,如果要将8个32位元素放入整数寄存器,实际上存储到堆栈然后执行标量加载会更快。 pextrd是关于SnB家族和Bulldozer家族的2-uop指令。 (以及Nehalem和Silvermont,不支持AVX)。

唯一一个vextractf128 + 2x movd + 6x pextrd并不可怕的CPU是AMD Jaguar。 (便宜pextrd,只有一个加载端口。)(参见Agner Fog's insn tables

宽对齐的商店可以转发到重叠的窄负载。 (当然,您可以使用movd来获取低元素,因此您可以混合使用加载端口和ALU端口uops。

当然,你似乎是通过使用整数提取然后将其转换回浮点数来提取float这看起来很糟糕。

你真正需要的是它自己的xmm寄存器的低元素中的每个floatvextractf128显然是开始的方式,将元素4带到新的xmm reg的底部。那么6x AVX shufps可以轻松获得每一半的其他三个元素。 (或movshdupmovhlps编码较短:没有立即字节。

7个shuffle uops值得考虑,1个商店和7个负载uops,但是如果你打算溢出函数调用的向量,那就不行了。

ABI注意事项:

你在Windows上,其中xmm6-15被调用保留(只有low128; ymm6-15的上半部分是call-clobbered)。这是从vextractf128开始的另一个原因。

在SysV ABI中,所有xmm / ymm / zmm寄存器都是call-clobbered,因此每个print()函数都需要溢出/重新加载。唯一明智的做法是存储到内存并使用原始向量调用print(即打印低元素,因为它将忽略寄存器的其余部分)。然后movss xmm0, [rsp+4]并在第二个元素上调用print

将8个浮点数很好地解压缩到8个向量寄存器中没有任何好处,因为在第一次函数调用之前它们都必须单独溢出!

答案 2 :(得分:1)

(未完成的答案。无论如何都要张贴以防万一,或者我回到它。一般来说,如果你需要与标量接口,你无法进行矢量化,那么它不是不好将矢量存储到本地数组,然后一次重新加载一个元素。)

请参阅我对asm详细信息的其他答案。这个答案是关于C ++方面的。

使用Agner Fog's Vector Class Library,他的包装类重载operator[]以完全按照您预期的方式工作,即使对于非常数args也是如此。这通常编译为存储/重新加载,但它使得用C ++编写代码变得容易。启用优化后,您可能会获得不错的结果。 (除了低元素可能会被存储/重新加载,而不是仅仅使用到位。所以你可能需要特殊情况vec[0]_mm_cvtss_f32(vec)或其他东西。)

另请参阅我的github repo,其中对Agner的VCL进行了大多数未经测试的更改,以便为某些功能生成更好的代码。

有一个_MM_EXTRACT_FLOAT wrapper macro,但它很奇怪,只能用SSE4.1定义。我认为它打算使用SSE4.1 extractps(它可以将float的二进制表示提取到整数寄存器中,或存储到内存中)。不过,当目的地为float时,gcc会将其编译为FP shuffle。如果您希望结果为extractps,请注意其他编译器不要将其编译为实际float指令,因为that's not what extractps会这样做。 (那就是insertps does,但更简单的FP shuffle会占用更少的指令字节。例如shufps与AVX相比很棒。)

这很奇怪,因为它需要3个参数:_MM_EXTRACT_FLOAT(dest, src_m128, idx),因此您甚至无法将其用作float本地的初始化程序。

循环向量

gcc将为您展开这样的循环,但仅限于-O1或更高版本。在-O0,它会显示错误消息。

float bad_hsum(__m128 & fv) {
    float sum = 0;
    for (int i=0 ; i<4 ; i++) {
        float f;
        _MM_EXTRACT_FLOAT(f, fv, i);  // works only with -O1 or higher
        sum += f;
    }
    return sum;
}

答案 3 :(得分:1)

    float valueAVX(__m256 a, int i){

        float ret = 0;
        switch (i){

            case 0:
//                 a = ( a7, a6, a5, a4, a3, a2, a1, a0 )
// extractf(a, 0)      ( a3, a2, a1, a0 )
// cvtss_f32             a0 

                ret = _mm_cvtss_f32(_mm256_extractf128_ps(a, 0));
                break;
            case 1: {
//                     a = ( a7, a6, a5, a4, a3, a2, a1, a0 )
// extractf(a, 0)     lo = ( a3, a2, a1, a0 )
// shuffle(lo, lo, 1)      ( - , a3, a2, a1 )
// cvtss_f32                 a1 
                __m128 lo = _mm256_extractf128_ps(a, 0);
                ret = _mm_cvtss_f32(_mm_shuffle_ps(lo, lo, 1));
            }
                break;
            case 2: {
//                   a = ( a7, a6, a5, a4, a3, a2, a1, a0 )
// extractf(a, 0)   lo = ( a3, a2, a1, a0 )
// movehl(lo, lo)        ( - , - , a3, a2 )
// cvtss_f32               a2 
                __m128 lo = _mm256_extractf128_ps(a, 0);
                ret = _mm_cvtss_f32(_mm_movehl_ps(lo, lo));
            }
                break;
            case 3: {
//                   a = ( a7, a6, a5, a4, a3, a2, a1, a0 )
// extractf(a, 0)   lo = ( a3, a2, a1, a0 )
// shuffle(lo, lo, 3)    ( - , - , - , a3 )
// cvtss_f32               a3 
                __m128 lo = _mm256_extractf128_ps(a, 0);                    
                ret = _mm_cvtss_f32(_mm_shuffle_ps(lo, lo, 3));
            }
                break;

            case 4:
//                 a = ( a7, a6, a5, a4, a3, a2, a1, a0 )
// extractf(a, 1)      ( a7, a6, a5, a4 )
// cvtss_f32             a4 
                ret = _mm_cvtss_f32(_mm256_extractf128_ps(a, 1));
                break;
            case 5: {
//                     a = ( a7, a6, a5, a4, a3, a2, a1, a0 )
// extractf(a, 1)     hi = ( a7, a6, a5, a4 )
// shuffle(hi, hi, 1)      ( - , a7, a6, a5 )
// cvtss_f32                 a5 
                __m128 hi = _mm256_extractf128_ps(a, 1);
                ret = _mm_cvtss_f32(_mm_shuffle_ps(hi, hi, 1));
            }
                break;
            case 6: {
//                   a = ( a7, a6, a5, a4, a3, a2, a1, a0 )
// extractf(a, 1)   hi = ( a7, a6, a5, a4 )
// movehl(hi, hi)        ( - , - , a7, a6 )
// cvtss_f32               a6 
                __m128 hi = _mm256_extractf128_ps(a, 1);
                ret = _mm_cvtss_f32(_mm_movehl_ps(hi, hi));
            }
                break;
            case 7: {
//                   a = ( a7, a6, a5, a4, a3, a2, a1, a0 )
// extractf(a, 1)   hi = ( a7, a6, a5, a4 )
// shuffle(hi, hi, 3)    ( - , - , - , a7 )
// cvtss_f32               a7 
                __m128 hi = _mm256_extractf128_ps(a, 1);
                ret = _mm_cvtss_f32(_mm_shuffle_ps(hi, hi, 3));
            }
                break;
        }

        return ret;
    }