如何优化长系列的If / then条件表达式 - SIMD

时间:2015-07-31 09:56:20

标签: c++ c optimization conditional simd

我使用SIMD来提高C代码的性能,但我遇到了一个带有很多if / then条件的函数,如下所示:

if (Di <= -T3) return  -4;
if (Di <= -T2) return  -3;
if (Di <= -T1) return  -2;
if (Di < -NEAR)  return  -1;
if (Di <=  NEAR) return   0;
if (Di < T1)   return   1;
if (Di < T2)   return   2;
if (Di < T3)   return   3;

return  4;

使用VC ++编译器支持的Intel内部函数可以缩短处理时间。

那么有没有更好的方法来优化这一长串的条件表达式?

2 个答案:

答案 0 :(得分:6)

我假设了几件事:

  1. 您处理int32数据(但很容易将其更改为float32)。
  2. 您可以一次将4个值传递给您的函数(不只是一个)。这就是人们通常所说的矢量化
  3. 对常数进行排序,即0&lt; NEAR&lt; T1&lt; T2&lt; T3。
  4. 这是一个矢量化函数:

    __m128i func4(__m128i D) {
      __m128i cmp_m3 = _mm_cmpgt_epi32(D, _mm_set1_epi32(-T3));
      __m128i cmp_m2 = _mm_cmpgt_epi32(D, _mm_set1_epi32(-T2));
      __m128i cmp_m1 = _mm_cmpgt_epi32(D, _mm_set1_epi32(-T1));
      __m128i cmp_p0 = _mm_cmpgt_epi32(D, _mm_set1_epi32(NEAR));
      __m128i reduce_true = _mm_add_epi32(_mm_add_epi32(cmp_m3, cmp_m2), _mm_add_epi32(cmp_m1, cmp_p0));
      __m128i cmp_m0 = _mm_cmplt_epi32(D, _mm_set1_epi32(-NEAR));
      __m128i cmp_p1 = _mm_cmplt_epi32(D, _mm_set1_epi32(T1));
      __m128i cmp_p2 = _mm_cmplt_epi32(D, _mm_set1_epi32(T2));
      __m128i cmp_p3 = _mm_cmplt_epi32(D, _mm_set1_epi32(T3));
      __m128i reduce_false = _mm_add_epi32(_mm_add_epi32(cmp_p3, cmp_p2), _mm_add_epi32(cmp_p1, cmp_m0));
      return _mm_sub_epi32(reduce_false, reduce_true);
    }
    

    如果输入数据是随机的,那么它在使用MSVC2013 x64的Ivy Bridge上比原始版本快11倍:

    Time = 4.436   (-39910000)
    Time = 0.409   (-39910000)
    

    可以使用完整的测试代码here

    这个想法很简单。 您可以在上面的链接后面的函数funcX中看到建议解决方案的非矢量化版本。它可以解释一切比文字更好。

    我们将寄存器D作为输入,它包含4个打包值。 然后我们将它与_mm_cmp*内在的所有8个常数进行比较。此比较产生8位掩码cmp_pXcmp_mX。在位掩码中,对应于数字的所有位都是0或1。为每次比较设置32个零位,这是错误的。如果比较条件为真,则将32位设置为1.

    现在回想一下,带符号表示的所有一位的32位整数是-1。当我们将四个比较结果加在一起时,我们得到一组否定的计数。最后,我们采用两个计数的差异,这是期望的结果。

    P.S。这是为内循环生成的汇编代码:

    movdqa  xmm3, XMMWORD PTR [rcx]
    movdqa  xmm4, xmm10
    movdqa  xmm0, xmm9
    add rcx, 16
    pcmpgtd xmm0, xmm3
    pcmpgtd xmm4, xmm3
    paddd   xmm4, xmm0
    movdqa  xmm2, xmm3
    movdqa  xmm1, xmm8
    pcmpgtd xmm1, xmm3
    pcmpgtd xmm2, xmm14
    movdqa  xmm0, xmm7
    pcmpgtd xmm0, xmm3
    paddd   xmm1, xmm0
    paddd   xmm4, xmm1
    movdqa  xmm0, xmm3
    movdqa  xmm1, xmm3
    pcmpgtd xmm1, xmm12
    pcmpgtd xmm0, xmm13
    pcmpgtd xmm3, xmm11
    paddd   xmm1, xmm3
    paddd   xmm2, xmm0
    paddd   xmm2, xmm1
    psubd   xmm4, xmm2
    paddd   xmm4, xmm5
    movdqa  xmm5, xmm4
    cmp rcx, r15
    jl  SHORT $LL3@main
    

答案 1 :(得分:0)

您可以尝试完全摆脱条件并再次测量时间。 你的代码

if (Di <= -T3) return  -4;
if (Di <= -T2) return  -3;
if (Di <= -T1) return  -2;
if (Di < -NEAR)  return  -1;
if (Di <=  NEAR) return   0;
if (Di < T1)   return   1;
if (Di < T2)   return   2;
if (Di < T3)   return   3;

return  4;

可以转换为无条件形式:

return
    (Di <= -T3)*(-4) + (Di > -T3) * (
    (Di <= -T2)*(-3) + (Di > -T2) * (
    (Di <= -T1)*(-2) + (Di > -T1) * (
    (Di < -NEAR)*(-1) + (Di >= -NEAR) * (
    (Di <=  NEAR)*0 + (Di > NEAR) * (
    (Di < T1)*1 + (Di >= T1) * (
    (Di < T2)*2 + (Di >= T2) * (
    (Di < T3)*3 + (Di >= T3) * (
    4
    ))))))));

可能您可以进一步优化此代码,了解变量的可能内容。