AVX内在澄清,4x4矩阵乘法奇数

时间:2017-03-23 14:32:10

标签: c++ c algorithm avx

在纸面上,我提出了这种算法的长形式,在纸上它应该可以正常工作。我是否在使用寄存器转换(256/128/256)时遇到了一个微妙的问题,或者我是否真的在某处弄乱了算法结构?

为方便起见,我将香草代码和AVX代码放在Godbolt查看器上,这样您就可以随意查看生成的程序集。

标准代码 https://godbolt.org/g/v47RKH

我的AVX尝试1: https://godbolt.org/g/oH1DpO

我的AVX尝试2: https://godbolt.org/g/QFtdKr(剃光了5个周期,减少了施法需求,更容易阅读)

SSE代码奇怪的是使用标量操作,这令人难以置信,因为水平广播,muls和添加肯定可以加速。我想要做的就是把这个概念提升到一个层次。

RHS永远不需要改变,但基本上如果LHS是{a,b,...,p}, 并且LHS是{1,2,...,16},那么我们只需要2个寄存器来保存RHS的两半,然后需要2个寄存器来保存给定行的LHS,形式为{a,a,a,a ,b,b,b,b}和{c,c,c,c,d,d,d,d}。这是通过2个广播和256/128/256演员实现的。

我们得到

的中间结果
  

{a * 1,a * 2,a * 3,a * 4,b * 5,b * 6,b * 7,b * 8} =>行[0]

  

{c * 9,c * 10,c * 11,c * 12,d * 13,d * 14,d * 15,d * 16} =>行[1]

这是在LHS一次展开的,所以我们生成了

  

{e * 1,... f * 8},{g * 9,... h * 16} => row [2],row [3]

接下来将r0,r1和r2,r3加在一起(保持r0和r2为当前中间体)

最后,将行[0]的高半部分提取到resHalf的低半部分,将行[2]的低半部分插入resHalf的高半部分,将行[2]的高半部分插入高半部分的行[0],然后将行[0]添加到resHalf。

根据所有权利,这应该让我们在迭代结束时使用resHalf [0]等于以下内容i = 0

  

{a * 1 + b * 2 + c * 3 + d * 4,a * 5 + b * 6 + c * 7 + d * 8,

     

a * 9 + b * 10 + c * 11 + d * 12,a * 13 + b * 14 + c * 15 + d * 16,

     

e * 1 + ... + h * 4,e * 5 + ... + h * 8,

     

e * 9 + ... + h * 12,e * 13 + ... + h * 16}

我的算法产生的是以下内容:

  

2x {a * 1 + c * 3,a * 5 + c * 7,a * 9 + c * 11,a * 13 + c * 15},

     

2x {e * 1 + g * 3,e * 5 + g * 7,e * 9 + g * 11,e * 13 + g * 15}

如果我在三元条件中交换rhsHolders [0/1],那么更可怕的是,它根本不会改变结果。这就好像编译器忽略了其中一个交换和添加。 Clang 4和GCC 7都这样做,所以我在哪里搞砸了?

编辑:输出应该是4行{10,26,42,58},但我得到{4,12,20,28}

2 个答案:

答案 0 :(得分:1)

  

SSE代码奇怪的是使用标量操作,这让我大吃一惊,因为水平广播,muls和添加肯定可以加速。

你的意思是编译器生成的汇编代码? clang4.0和gcc7.1输出中MatMul()中的所有AVX指令都在ymm向量上运行。除了铿锵的愚蠢的广播负载:它执行标量加载,然后是单独的AVX2广播指令,这是非常糟糕的,因为英特尔CPU将广播负载作为单uop ALU指令处理。加载端口本身可以进行广播。但是如果源是一个寄存器,它需要一个ALU uop用于shuffle端口。

    vmovss  xmm5, dword ptr [rdi + 24] # xmm5 = mem[0],zero,zero,zero
    vbroadcastss    xmm5, xmm5

clang的实际输出(上图)与gcc使用的AVX1 vbroadcastss xmm5, [rdi + 24]相比真的很傻。

main()中,clang会发出标量操作

由于您的输入矩阵都是编译时常量,唯一的谜团就是为什么它没有优化到cout << "a long string with the numbers already formatted\n";,或者至少优化掉所有数学并且只有double结果已准备好进行打印。 (是的,他们正在使用float在打印循环中从double转换为vcvtss2sd。)

它通过一些内在的shuffle和数学进行优化,在编译时完成它们。我猜clang在洗牌的某个地方迷路了,仍然发出了一些数学运算。它们是标量的事实可能表明它在编译时没有做太多工作,但它并没有重新排序以对其进行矢量化。

请注意,某些常量不会出现在源代码中,并且它们在内存中不是按升序排列的。

...
.LCPI1_5:
        .long   1092616192              # float 10
.LCPI1_6:
        .long   1101004800              # float 20
.LCPI1_7:
        .long   1098907648              # float 16
...

在位模式的整数表示之后,clang将float值放在注释中真的很好。

  

或者我是否真的搞砸了某处的算法结构?

嗯,这部分实现看起来完全是假的。您从lowerHalf初始化rows[j],但在下一个语句中覆盖该值。

__m128 lowerHalf = _mm256_castps256_ps128(rows[j]);
    lowerHalf = _mm_broadcast_ss(&lhs[offset+2*j]);

然后你用rows[j]未定义的上部128b通道进行256b乘法运算。

    rows[j] = _mm256_castps128_ps256(lowerHalf);
    rows[j] = _mm256_mul_ps(rows[j], (chooser) ? rhsHolders[0] : rhsHolders[1]);

在来自gcc和clang的asm中,上部通道全部为零(因为它们明显选择使用最后由标量写的ymm寄存器 - &gt; xmm广播,隐式零延伸到最大矢量宽度)。请注意_mm256_castps128_ps256无法保证零扩展。除非__m128本身是256b或更宽向量的提取/强制转换的结果,否则它很可能是未定义的。有关在向量中需要归零上层通道的情况,请参阅How to clear the upper 128 bits of __m256 value?

无论如何,这意味着你从128b向量乘法(vmulps xmm, xmm, xmm)得到相同的结果:在这些指令之后,高4个元素都将为零(或NaN)

    vbroadcastss    xmm0, DWORD PTR [rdi+40]
    vmulps  ymm0, ymm2, ymm0

这种asm输出(来自gcc7.1)极不可能成为正确matmul实现的一部分。

我没有仔细查看你在源头上想要做什么,但我认为这不是这个。

  

如果我在三元条件中交换rhsHolders [0/1],那么更可怕的是,它根本不会改变结果。好像编译器忽略了其中一个交换和添加。

当更改源中的某些内容时,不会产生您在asm输出中所期望的更改,这暗示您可能会错误地找到源,并且正在优化某些内容。有时候我会复制/粘贴一个内在函数而忘记在新行中更改输入变量,所以我的函数会忽略它的一些计算结果并使用另一个两次。

答案 1 :(得分:0)

它几乎可以复制并粘贴我昨天在SO上的答案:)

试试这个

void MatMul(const float* __restrict lhs , const float* __restrict rhs , float* __restrict out ) 
{
  lhs = reinterpret_cast<float*>(__builtin_assume_aligned (lhs, 32));
  rhs = reinterpret_cast<float*>(__builtin_assume_aligned (rhs, 32));
  out = reinterpret_cast<float*>(__builtin_assume_aligned (out, 32));
  for(int i = 0; i < 4; i++){
    for(int j = 0; j < 4; j++){
      for (int k = 0; k < 4; k++){
        out[i*4 + j] += lhs[i*4 + k]*rhs[k*4 + i];
      }
    }     
  }     
}

使用下面的一个编译(测量哪一个最快)

-O3 -mavx
-O3 -mavx2
-O3 -mavx2 -mfma
-O3 -mavx2 -mfma -ffast-math

这在GCC下工作(我的意思是矢量化),cLANG由于某种原因没有这样做。 GCC也将展开循环。