使用SSE(x * x * x)+(y * y * y)进行乘法

时间:2011-12-02 13:39:19

标签: c x86 sse simd

我尝试使用SIMD优化此功能,但我不知道从哪里开始。

long sum(int x,int y)
{
    return x*x*x+y*y*y;
}

反汇编函数如下所示:

  4007a0:   48 89 f2                mov    %rsi,%rdx
  4007a3:   48 89 f8                mov    %rdi,%rax
  4007a6:   48 0f af d6             imul   %rsi,%rdx
  4007aa:   48 0f af c7             imul   %rdi,%rax
  4007ae:   48 0f af d6             imul   %rsi,%rdx
  4007b2:   48 0f af c7             imul   %rdi,%rax
  4007b6:   48 8d 04 02             lea    (%rdx,%rax,1),%rax
  4007ba:   c3                      retq   
  4007bb:   0f 1f 44 00 00          nopl   0x0(%rax,%rax,1)

调用代码如下所示:

 do {
for (i = 0; i < maxi; i++) {
  j = nextj[i];
  long sum = cubeSum(i,j);
  while (sum <= p) {
    long x = sum & (psize - 1);
    int flag = table[x];
    if (flag <= guard) {
      table[x] = guard+1;
    } else if (flag == guard+1) {
      table[x] = guard+2;
      count++;
    }
    j++;
    sum = cubeSum(i,j);
  }
  nextj[i] = j;
}
p += psize;
guard += 3;
} while (p <= n);

2 个答案:

答案 0 :(得分:6)

  • 用(x | y | 0 | 0)填充一个SSE寄存器(因为每个SSE寄存器包含4个32位元素)。让我们称之为r1
  • 然后将该寄存器的副本复制到另一个寄存器r2
  • 执行r2 * r1,将结果存储在r2。
  • r2 * r1再次将结果存储在r2
  • 现在在r2你有(x * x * x | y * y * y | 0 | 0)
  • 将r2的下两个元素打包成单独的寄存器,添加它们(SSE3有水平添加指令,但仅适用于浮点数和双精度数。)

最后,如果结果比编译器为您生成的简单代码更快,我会感到惊讶。如果您有要操作的数据数组,SIMD会更有用。

答案 1 :(得分:1)

这种特殊情况不适合SIMD(SSE或其他)。当你有连续的数组可以顺序访问并且处理异构时,SIMD真的只能运行良好。

但是,您至少可以摆脱标量代码中的一些冗余操作,例如:当i * i * i不变时重复计算i

do {
    for (i = 0; i < maxi; i++) {
        int i3 = i * i * i;
        int j = nextj[i];
        int j3 = j * j * j;
        long sum = i3 + j3;
        while (sum <= p) {
            long x = sum & (psize - 1);
            int flag = table[x];
            if (flag <= guard) {
              table[x] = guard+1;
            } else if (flag == guard+1) {
              table[x] = guard+2;
              count++;
            }
            j++;
            j3 = j * j * j;
            sum = i3 + j3;
        }
        nextj[i] = j;
    }
    p += psize;
    guard += 3;
} while (p <= n);