我可以使用AVX FMA单元进行精确的52位整数乘法吗?

时间:2016-12-30 22:54:57

标签: floating-point x86 simd avx2 fma

AXV2没有任何整数乘法,其源大于32位。它确实提供了32 x 32 -> 32次乘法,以及32 x 32 -> 64乘以 1 ,但没有64位来源。

假设我需要一个大于32位但小于或等于52位的无符号乘法 - 我可以简单地使用浮点DP multiply或FMA指令,并且当整数输入和结果可以用52或更少的位表示时(即,在[0,2 ^ 52-1]范围内),输出是位精确的吗?

我想要所有104位产品的更一般情况怎么样?或者整数乘积超过52位的情况(即,产品在位索引中具有非零值> 52) - 但我只想要低52位?在后一种情况下,MUL将给我更高的位并舍去一些低位(也许是IFMA帮助的那些?)。

编辑:实际上,根据this answer,它可能会执行最多2 ^ 53的任何事情 - 我已经忘记了尾数前面隐含的前导1再给你一点。

1 有趣的是,64位产品PMULDQ操作具有一半的延迟和32位PMULLD版本的吞吐量的两倍,如Mysticial explains在评论中。

3 个答案:

答案 0 :(得分:13)

是的,这是可能的。但是从AVX2开始,它不太可能比使用MULX / ADCX / ADOX的标量方法更好。

对于不同的输入/输出域,这种方法实际上有无限多种变体。我只会覆盖其中的3个,但一旦你知道它们是如何工作的,它们很容易概括。

免责声明:

  • 此处的所有解决方案均假设舍入模式为round-to-even。
  • 建议不要使用快速数学优化标志,因为这些解决方案依赖严格的IEEE。

签名在范围内翻倍:[ - 2 51 ,2 51 ]

//  A*B = L + H*2^52
//  Input:  A and B are in the range [-2^51, 2^51]
//  Output: L and H are in the range [-2^51, 2^51]
void mul52_signed(__m256d& L, __m256d& H, __m256d A, __m256d B){
    const __m256d ROUND = _mm256_set1_pd(30423614405477505635920876929024.);    //  3 * 2^103
    const __m256d SCALE = _mm256_set1_pd(1. / 4503599627370496);                //  1 / 2^52

    //  Multiply and add normalization constant. This forces the multiply
    //  to be rounded to the correct number of bits.
    H = _mm256_fmadd_pd(A, B, ROUND);

    //  Undo the normalization.
    H = _mm256_sub_pd(H, ROUND);

    //  Recover the bottom half of the product.
    L = _mm256_fmsub_pd(A, B, H);

    //  Correct the scaling of H.
    H = _mm256_mul_pd(H, SCALE);
}

这是最简单的一种,也是唯一一种与标量方法竞争的产品。最终缩放是可选的,具体取决于您要对输出执行的操作。所以这只能被认为是3条指令。但它也是最不实用的,因为输入和输出都是浮点值。

两个FMA保持融合绝对至关重要。这就是快速数学优化可以破坏事物的地方。如果第一个FMA被分解,则L不再保证在[-2^51, 2^51]范围内。如果第二个FMA被解散,L将完全错误。

范围内的有符号整数:[ - 2 51 ,2 51 ]

//  A*B = L + H*2^52
//  Input:  A and B are in the range [-2^51, 2^51]
//  Output: L and H are in the range [-2^51, 2^51]
void mul52_signed(__m256i& L, __m256i& H, __m256i A, __m256i B){
    const __m256d CONVERT_U = _mm256_set1_pd(6755399441055744);     //  3*2^51
    const __m256d CONVERT_D = _mm256_set1_pd(1.5);

    __m256d l, h, a, b;

    //  Convert to double
    A = _mm256_add_epi64(A, _mm256_castpd_si256(CONVERT_U));
    B = _mm256_add_epi64(B, _mm256_castpd_si256(CONVERT_D));
    a = _mm256_sub_pd(_mm256_castsi256_pd(A), CONVERT_U);
    b = _mm256_sub_pd(_mm256_castsi256_pd(B), CONVERT_D);

    //  Get top half. Convert H to int64.
    h = _mm256_fmadd_pd(a, b, CONVERT_U);
    H = _mm256_sub_epi64(_mm256_castpd_si256(h), _mm256_castpd_si256(CONVERT_U));

    //  Undo the normalization.
    h = _mm256_sub_pd(h, CONVERT_U);

    //  Recover bottom half.
    l = _mm256_fmsub_pd(a, b, h);

    //  Convert L to int64
    l = _mm256_add_pd(l, CONVERT_D);
    L = _mm256_sub_epi64(_mm256_castpd_si256(l), _mm256_castpd_si256(CONVERT_D));
}

建立第一个例子,我们将它与fast double <-> int64 conversion trick的通用版本结合起来。

这个更有用,因为你正在使用整数。但即使使用快速转换技巧,大部分时间都将用于转换。幸运的是,如果多次乘以相同的操作数,则可以消除一些输入转换。

范围内的无符号整数:[0,2 52

//  A*B = L + H*2^52
//  Input:  A and B are in the range [0, 2^52)
//  Output: L and H are in the range [0, 2^52)
void mul52_unsigned(__m256i& L, __m256i& H, __m256i A, __m256i B){
    const __m256d CONVERT_U = _mm256_set1_pd(4503599627370496);     //  2^52
    const __m256d CONVERT_D = _mm256_set1_pd(1);
    const __m256d CONVERT_S = _mm256_set1_pd(1.5);

    __m256d l, h, a, b;

    //  Convert to double
    A = _mm256_or_si256(A, _mm256_castpd_si256(CONVERT_U));
    B = _mm256_or_si256(B, _mm256_castpd_si256(CONVERT_D));
    a = _mm256_sub_pd(_mm256_castsi256_pd(A), CONVERT_U);
    b = _mm256_sub_pd(_mm256_castsi256_pd(B), CONVERT_D);

    //  Get top half. Convert H to int64.
    h = _mm256_fmadd_pd(a, b, CONVERT_U);
    H = _mm256_xor_si256(_mm256_castpd_si256(h), _mm256_castpd_si256(CONVERT_U));

    //  Undo the normalization.
    h = _mm256_sub_pd(h, CONVERT_U);

    //  Recover bottom half.
    l = _mm256_fmsub_pd(a, b, h);

    //  Convert L to int64
    l = _mm256_add_pd(l, CONVERT_S);
    L = _mm256_sub_epi64(_mm256_castpd_si256(l), _mm256_castpd_si256(CONVERT_S));

    //  Make Correction
    H = _mm256_sub_epi64(H, _mm256_srli_epi64(L, 63));
    L = _mm256_and_si256(L, _mm256_set1_epi64x(0x000fffffffffffff));
}

最后,我们得到原始问题的答案。这通过调整转换并添加更正步骤来构建有符号整数解决方案。

但是在这一点上,我们处理了13条指令 - 其中一半是高延迟指令,不包括大量FP <-> int旁路延迟。所以这不太可能赢得任何基准。相比之下,64 x 64 -> 128-bit SIMD乘法可以在16条指令中完成(如果预处理输入则为14条)。

如果舍入模式是向下舍入或舍入为零,则可以省略校正步骤。唯一重要的指示是h = _mm256_fmadd_pd(a, b, CONVERT_U);。因此,在AVX512上,您可以覆盖该指令的舍入并单独保留舍入模式。

最后的想法:

值得注意的是,通过调整魔术常数可以减少2 52 的操作范围。这可能对第一个解决方案(浮点数)有用,因为它为您提供额外的尾数用于浮点累积。这样就可以避免在最后两个解决方案中不断地在int64和double之间来回转换。

虽然这里的3个示例不太可能比标量方法更好,但AVX512几乎肯定会取得平衡。特别是Knights Landing的ADCX和ADOX吞吐量很低。

当然,当AVX512-IFMA问世时,所有这一切都没有实际意义。这会将完整的52 x 52 -> 104-bit产品减少为2条指令,并免费提供积累。

答案 1 :(得分:3)

进行多字整数运算的一种方法是使用double-double arithmetic。让我们从一些双倍乘法代码开始

#include <math.h>
typedef struct {
  double hi;
  double lo;
} doubledouble;

static doubledouble quick_two_sum(double a, double b) {
  double s = a + b;
  double e = b - (s - a);
  return (doubledouble){s, e};
}

static doubledouble two_prod(double a, double b) {
  double p = a*b;
  double e = fma(a, b, -p);
  return (doubledouble){p, e};
}

doubledouble df64_mul(doubledouble a, doubledouble b) {
  doubledouble p = two_prod(a.hi, b.hi);
  p.lo += a.hi*b.lo;
  p.lo += a.lo*b.hi;
  return quick_two_sum(p.hi, p.lo);
}

函数two_prod可以执行整数53bx53b - &gt; 106b在两个指令中。函数df64_mul可以做整数106bx106b - &gt; 106B。

让我们将它与整数128bx128b进行比较 - &gt; 128b,带整数硬件。

__int128 mul128(__int128 a, __int128 b) {
  return a*b;
}

mul128

的程序集
imul    rsi, rdx
mov     rax, rdi
imul    rcx, rdi
mul     rdx
add     rcx, rsi
add     rdx, rcx

df64_mul的汇编(使用gcc -O3 -S i128.c -masm=intel -mfma -ffp-contract=off编译)

vmulsd      xmm4, xmm0, xmm2
vmulsd      xmm3, xmm0, xmm3
vmulsd      xmm1, xmm2, xmm1
vfmsub132sd xmm0, xmm4, xmm2
vaddsd      xmm3, xmm3, xmm0
vaddsd      xmm1, xmm3, xmm1
vaddsd      xmm0, xmm1, xmm4
vsubsd      xmm4, xmm0, xmm4
vsubsd      xmm1, xmm1, xmm4

mul128进行三次标量乘法和两次标量加法/减法,而df64_mul进行3次SIMD乘法,1次SIMD FMA和5次SIMD加法/减法。我没有对这些方法进行分析,但对我来说,df64_mul每个AVX寄存器使用4-double可以胜过mul128并且sd更改为pd和{{xmm,这似乎不合理。 1}}到ymm)。

很容易说问题是切换回整数域。但为什么这有必要呢?您可以在浮点域中执行所有操作。我们来看一些例子。我发现使用float进行单元测试比使用double进行单元测试更容易。

doublefloat two_prod(float a, float b) {
  float p = a*b;
  float e = fma(a, b, -p);
  return (doublefloat){p, e};
}

//3202129*4807935=15395628093615
x = two_prod(3202129,4807935)
int64_t hi = p, lo = e, s = hi+lo
//p = 1.53956280e+13, e = 1.02575000e+05  
//hi = 15395627991040, lo = 102575, s = 15395628093615

//1450779*1501672=2178594202488
y = two_prod(1450779, 1501672)
int64_t hi = p, lo = e, s = hi+lo 
//p = 2.17859424e+12, e = -4.00720000e+04
//hi = 2178594242560 lo = -40072, s = 2178594202488

所以我们最终得到不同的范围,在第二种情况下,错误(e)甚至是负数,但总和仍然是正确的。我们甚至可以将两个双浮点值xy一起添加(一旦我们知道如何进行双重加法 - 请参阅最后的代码)并获取15395628093615+2178594202488。没有必要对结果进行标准化。

但是加法带来了双重算术的主要问题。即,加法/减法很慢,例如128b + 128b - &gt; 128b needs at least 11 floating point additions而整数只需要两个(addadc)。

因此,如果一个算法在乘法上很重,但在加法时很轻,那么用双倍的多字整数运算就可以获胜。

作为旁注,C语言足够灵活,允许实现完全通过浮点硬件实现整数的实现。 int可以是24位(来自单个浮点),long可以是54位。 (来自双浮点),long long可以是106位(来自双倍)。 C甚至不需要两个恭维,因此整数可以使用带符号的幅度来表示负数,就像浮点数一样。

以下是使用双倍乘法和加法的C代码(我没有实现除法或其他操作,如sqrt,但有文件显示如何执行此操作)以防有人想要使用它。看看这是否可以针对整数进行优化会很有趣。

//if compiling with -mfma you must also use -ffp-contract=off
//float-float is easier to debug. If you want double-double replace
//all float words with double and fmaf with fma 
#include <stdio.h>
#include <math.h>
#include <inttypes.h>
#include <x86intrin.h>
#include <stdlib.h>

//#include <float.h>

typedef struct {
  float hi;
  float lo;
} doublefloat;

typedef union {
  float f;
  int i;
  struct {
    unsigned mantisa : 23;
    unsigned exponent: 8;
    unsigned sign: 1;
  };
} float_cast;

void print_float(float_cast a) {
  printf("%.8e, 0x%x, mantisa 0x%x, exponent 0x%x, expondent-127 %d, sign %u\n", a.f, a.i, a.mantisa, a.exponent, a.exponent-127, a.sign);
}

void print_doublefloat(doublefloat a) {
  float_cast hi = {a.hi};
  float_cast lo = {a.lo};
  printf("hi: "); print_float(hi);
  printf("lo: "); print_float(lo);
}

doublefloat quick_two_sum(float a, float b) {
  float s = a + b;
  float e = b - (s - a);
  return (doublefloat){s, e};
  // 3 add
}

doublefloat two_sum(float a, float b) {
  float s = a + b;
  float v = s - a;
  float e = (a - (s - v)) + (b - v);
  return (doublefloat){s, e};
  // 6 add 
}

doublefloat df64_add(doublefloat a, doublefloat b) {
  doublefloat s, t;
  s = two_sum(a.hi, b.hi);
  t = two_sum(a.lo, b.lo);
  s.lo += t.hi;
  s = quick_two_sum(s.hi, s.lo);
  s.lo += t.lo;
  s = quick_two_sum(s.hi, s.lo);
  return s;
  // 2*two_sum, 2 add, 2*quick_two_sum = 2*6 + 2 + 2*3 = 20 add
}

doublefloat split(float a) {
  //#define SPLITTER (1<<27) + 1
#define SPLITTER (1<<12) + 1
  float t = (SPLITTER)*a;
  float hi = t - (t - a);
  float lo = a - hi;
  return (doublefloat){hi, lo};
  // 1 mul, 3 add
}

doublefloat split_sse(float a) {
  __m128 k = _mm_set1_ps(4097.0f);
  __m128 a4 = _mm_set1_ps(a);
  __m128 t = _mm_mul_ps(k,a4);
  __m128 hi4 = _mm_sub_ps(t,_mm_sub_ps(t, a4));
  __m128 lo4 = _mm_sub_ps(a4, hi4);
  float tmp[4];
  _mm_storeu_ps(tmp, hi4);
  float hi = tmp[0];
  _mm_storeu_ps(tmp, lo4);
  float lo = tmp[0];
  return (doublefloat){hi,lo};

}

float mult_sub(float a, float b, float c) {
  doublefloat as = split(a), bs = split(b);
  //print_doublefloat(as);
  //print_doublefloat(bs);
  return ((as.hi*bs.hi - c) + as.hi*bs.lo + as.lo*bs.hi) + as.lo*bs.lo;
  // 4 mul, 4 add, 2 split = 6 mul, 10 add
}

doublefloat two_prod(float a, float b) {
  float p = a*b;
  float e = mult_sub(a, b, p);
  return (doublefloat){p, e};
  // 1 mul, one mult_sub
  // 7 mul, 10 add
}

float mult_sub2(float a, float b, float c) {
  doublefloat as = split(a);
  return ((as.hi*as.hi -c ) + 2*as.hi*as.lo) + as.lo*as.lo;
}

doublefloat two_sqr(float a) {
  float p = a*a;
  float e = mult_sub2(a, a, p);
  return (doublefloat){p, e};
}

doublefloat df64_mul(doublefloat a, doublefloat b) {
  doublefloat p = two_prod(a.hi, b.hi);
  p.lo += a.hi*b.lo;
  p.lo += a.lo*b.hi;
  return quick_two_sum(p.hi, p.lo);
  //two_prod, 2 add, 2mul, 1 quick_two_sum = 9 mul, 15 add 
  //or 1 mul, 1 fma, 2add 2mul, 1 quick_two_sum = 3 mul, 1 fma, 5 add
}

doublefloat df64_sqr(doublefloat a) {
  doublefloat p = two_sqr(a.hi);
  p.lo += 2*a.hi*a.lo;
  return quick_two_sum(p.hi, p.lo);
}

int float2int(float a) {
  int M = 0xc00000; //1100 0000 0000 0000 0000 0000
  a += M;
  float_cast x;
  x.f = a;
  return x.i - 0x4b400000;
}

doublefloat add22(doublefloat a, doublefloat b) {
  float r = a.hi + b.hi;
  float s = fabsf(a.hi) > fabsf(b.hi) ?
    (((a.hi - r) + b.hi) + b.lo ) + a.lo :
    (((b.hi - r) + a.hi) + a.lo ) + b.lo;
  return two_sum(r, s);  
  //11 add 
}

int main(void) {
  //print_float((float_cast){1.0f});
  //print_float((float_cast){-2.0f});
  //print_float((float_cast){0.0f});
  //print_float((float_cast){3.14159f});
  //print_float((float_cast){1.5f});
  //print_float((float_cast){3.0f});
  //print_float((float_cast){7.0f});
  //print_float((float_cast){15.0f});
  //print_float((float_cast){31.0f});

  //uint64_t t = 0xffffff;
  //print_float((float_cast){1.0f*t});
  //printf("%" PRId64 " %" PRIx64 "\n", t*t,t*t);

  /*
    float_cast t1;
    t1.mantisa = 0x7fffff;
    t1.exponent = 0xfe;
    t1.sign = 0;
    print_float(t1);
  */
  //doublefloat z = two_prod(1.0f*t, 1.0f*t);
  //print_doublefloat(z);
  //double z2 = (double)z.hi + (double)z.lo;
  //printf("%.16e\n", z2);
  doublefloat s = {0};
  int64_t si = 0;
  for(int i=0; i<100000; i++) {
    int ai = rand()%0x800, bi = rand()%0x800000;
    float a = ai, b = bi;
    doublefloat z = two_prod(a,b);
    int64_t zi = (int64_t)ai*bi;
    //print_doublefloat(z);
    //s = df64_add(s,z);
    s = add22(s,z);
    si += zi;
    print_doublefloat(z);
    printf("%d %d ", ai,bi);
    int64_t h = z.hi;
    int64_t l = z.lo;
    int64_t t = h+l;
    //if(t != zi) printf("%" PRId64 " %" PRId64 "\n", h, l);

    printf("%" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 "\n", zi, h, l, h+l);

    h = s.hi;
    l = s.lo;
    t = h + l;
    //if(si != t) printf("%" PRId64 " %" PRId64 "\n", h, l);

    if(si > (1LL<<48)) {
      printf("overflow after %d iterations\n", i); break;
    }
  }

  print_doublefloat(s);
  printf("%" PRId64 "\n", si);
  int64_t x = s.hi;
  int64_t y = s.lo;
  int64_t z = x+y;
  //int hi = float2int(s.hi);
  printf("%" PRId64 " %" PRId64 " %" PRId64 "\n", z,x,y);
}

答案 2 :(得分:2)

嗯,你肯定可以对整数事物进行FP-lane操作。并且它们总是准确的:虽然有SSE指令不保证正确的IEEE-754精度和舍入,但毫无例外它们是没有整数范围的那些,所以不是你正在看的那些。底线:加法/减法/乘法在整数域中始终是精确的,即使您在打包浮点数上执行它们也是如此。

对于四精度浮点数(> 52位尾数),不,这些不受支持,可能在可预见的未来。只是没有多少要求他们。他们出现在一些SPARC时代的工作站架构中,但说实话,他们只是开发人员对如何编写数值稳定算法的不完全理解的绷带,随着时间的推移它们逐渐淡出。

宽整数运算对SSE来说非常不合适。我最近在实现一个大整数库时真的试图利用它,说实话,这对我没有好处。 x86是设计的用于多字算术;你可以在诸如ADC(产生和消耗进位)和IDIV(它允许除数是被除数的两倍宽)的操作中看到它,只要商不宽于被除数,这是一个约束使得它对于任何多字分割都没用。但是多字算术本质上是顺序的,而SSE本质上是平行的。如果你足够幸运,你的数字足够位以适应FP尾数,恭喜你。但如果你有一般的大整数,SSE可能不会成为你的朋友。