考虑下面的函数,它将a * b的结果转换成几个数字i和j,其中:
如果不使用任何大于int的类型,你将如何计算j (如果小心溢出int的是UB):
#include <iostream>
#include <cstdlib>
using namespace std;
int n = 1000, m = 100000;
struct N {
int i, j;
};
N f(int a, int b) {
N x;
int a0, a1, b0, b1, o;
a1 = a / n;
a0 = a - (a1 * n); // a0 = a % n
b1 = b / m;
b0 = b - (b1 * m); // b0 = b % m
o = a1 * b1 + (a0 * b1) / n + (b0 * a1) / m;
x.i = o;
x.j = 0; // CALCULATE J WITH INTs MATH
return x;
}
int main(int, char* argv[]) {
int a = atoi(argv[1]),
b = atoi(argv[2]);
N x = f(a, b);
cout << a << " * " << b << " = " << x.i << "*" << n*m
<< " + " << x.j << endl;
cout << "which is: " << (long long)a * b << endl;
return 0;
}
答案 0 :(得分:2)
您开始正确,但在计算o
时丢失了图表。首先,我的假设:你不想处理任何大于n*m
的整数,所以取mod n*m
作弊。我这样说,因为给定m > 2^16
,我必须假设int是32位长,它能够处理你的数字而不会溢出。
无论如何。你有正确的(我猜,因为没有指定n
和m
的目的)写的:
a=a0 + a1*n (a0<n)
b=b0 + b1*m (b0<m)
所以,如果我们做数学:
a*b = a0*b0 + a0*b1*m + a1*b0*n + a1*b1*n*m
此处a0*b0 < n*m
,因此它是j
和a1*b1*n*m > n*m
的一部分,因此它是i
的一部分。另外两个术语需要再分为两个。但是你无法计算每一个并取mod n*m
,因为这会欺骗(根据我的规则)。如果你写:
a0*b1 = a0b1_0 + a0b1_1*n
你得到:
a0*b1*m = a0b1_0*m + a0b1_1*n*m
自a0b1_0 < n
,a0b1_0*m < n*m
以来,这意味着此部分转到j
。显然,a0b1_1
会转到i.
为a1 * b0重复一个类似的逻辑,你有三个词来加起来j
,还有三个加起来为i
。
您需要约束a < n^2
和b < m^2
才能生效。否则,您需要更多 i “words”。例如:a = a0 + a1*n + a2*n^2, ai < n
。
j
的最终总和可能大于n*m
。您需要注意溢出(n*m - o < addend
或类似逻辑,并在发生这种情况时将1
添加到i
- 同时计算j + addend - n*m
而不会溢出。)
答案 1 :(得分:1)
我认为答案是j = a0 * b0
(a*b)/(n*m) = (a/n) * (b/m)
= (a1 + a0/n) * (b1 + b0/m)
= a1*b1 + a1*b0/m + a0*b1/n + (a0*b0)/(n*m)
现在
o = a1*b1 + a1*b0/m + a0*b1/n
将两侧乘以n * m
a * b = o * n*m + a0*b0
n * m是基础
a * b = o * BASE + a0*b0
j = a0*b0
QED