你如何计算这个函数中的j?

时间:2011-06-21 10:57:52

标签: c++ algorithm math

考虑下面的函数,它将a * b的结果转换成几个数字i和j,其中:

  1. a,b,x,y为int(假设它们总是=> 32位长)
  2. a和b是< = n * m,其中n = 10 ^ 3且m = 10 ^ 5。 n * m = BASE。
  3. a * b可以写成i * BASE + j
  4. 如果不使用任何大于int的类型,你将如何计算j (如果小心溢出int的是UB):

    #include <iostream>
    #include <cstdlib>
    
    using namespace std;
    
    int n = 1000, m = 100000;
    
    struct N {
            int i, j;
    };
    
    N f(int a, int b) {
            N x;
            int a0, a1, b0, b1, o;
            a1 = a / n;
            a0 = a - (a1 * n); // a0 = a % n
            b1 = b / m;
            b0 = b - (b1 * m);  // b0 = b % m
            o = a1 * b1 + (a0 * b1) / n + (b0 * a1) / m;
            x.i = o;
            x.j = 0; // CALCULATE J WITH INTs MATH
            return x;
    }
    
    int main(int, char* argv[]) {
            int a = atoi(argv[1]),
            b = atoi(argv[2]);
            N x = f(a, b);
            cout << a << " * " << b << " = " << x.i << "*" << n*m 
                 << " + " << x.j << endl;
            cout << "which is: " << (long long)a * b << endl;
            return 0;
    }
    

2 个答案:

答案 0 :(得分:2)

您开始正确,但在计算o时丢失了图表。首先,我的假设:你不想处理任何大于n*m的整数,所以取mod n*m作弊。我这样说,因为给定m > 2^16,我必须假设int是32位长,它能够处理你的数字而不会溢出。

无论如何。你有正确的(我猜,因为没有指定nm的目的)写的:

a=a0 + a1*n (a0<n)
b=b0 + b1*m (b0<m)

所以,如果我们做数学:

a*b = a0*b0 + a0*b1*m + a1*b0*n + a1*b1*n*m

此处a0*b0 < n*m,因此它是ja1*b1*n*m > n*m的一部分,因此它是i的一部分。另外两个术语需要再分为两个。但是你无法计算每一个并取mod n*m,因为这会欺骗(根据我的规则)。如果你写:

a0*b1 = a0b1_0 + a0b1_1*n

你得到:

a0*b1*m = a0b1_0*m + a0b1_1*n*m

a0b1_0 < na0b1_0*m < n*m以来,这意味着此部分转到j。显然,a0b1_1会转到i.

为a1 * b0重复一个类似的逻辑,你有三个词来加起来j,还有三个加起来为i


编辑:忘了提几件事:

  • 您需要约束a < n^2b < m^2才能生效。否则,您需要更多 i “words”。例如:a = a0 + a1*n + a2*n^2, ai < n

  • j的最终总和可能大于n*m。您需要注意溢出(n*m - o < addend或类似逻辑,并在发生这种情况时将1添加到i - 同时计算j + addend - n*m而不会溢出。)

答案 1 :(得分:1)

我认为答案是j = a0 * b0

(a*b)/(n*m) = (a/n) * (b/m)
            = (a1 + a0/n) * (b1 + b0/m)
            = a1*b1 + a1*b0/m + a0*b1/n + (a0*b0)/(n*m)

现在

o = a1*b1 + a1*b0/m + a0*b1/n

将两侧乘以n * m

a * b  = o * n*m  +  a0*b0

n * m是基础

a * b  = o * BASE  +  a0*b0

j = a0*b0

QED