我无法理解为什么这个功能有效?有人能解释一下它一步一步做什么吗?我知道这个想法是如果n是偶数,则^ n等于(a ^(n / 2))^ 2或者如果n是奇数,则a(a ^((n-1)/ 2))^ 2,但是这个功能是如何做到的?
double pow(double a, int n) {
double ret = 1;
while(n) {
if(n%2 == 1) ret *= a;
a *= a; n /= 2;
}
return ret;
}
答案 0 :(得分:3)
本计划使用的平等点如下:
a^n * ret
是结果。事实上,在开头ret
是1
,而在循环结束时n == 0
,因此a^0 * ret
是结果,自a^0 == 1
以来, ret
是预期的结果。n
为奇数(即n%2 == 1
),则存在b≥0
,n=b*2+1
。在这种情况下,我们使用以下等式:a^(b*2+1)=(a^(b*2))*a
。因此ret
乘以a
。 a^(b*2) = (a^2)^b
,以便a
与自身相乘,n
除以2,最终保持不变量。< / LI>
醇>
请注意,在循环内部,整数除法用于n /= 2
,因此在两种情况下结果始终为b
(n
奇数,即n=b*2+1
,或n
是偶数,即n=b*2
)。
最后请注意,正如@chux在评论中指出的那样,该函数无法正确管理n
的负值。
答案 1 :(得分:1)
这是我的Python递归代码,它是IMO更易读和可理解的(我知道在Python中创建递归函数并不是一个好主意,但我之所以选择Python是因为其语法简单来演示这个想法)。
def pow(n, e):
if e == 0:
return 1
if e % 2 == 1:
return n * pow(n, e - 1)
# this step makes the algorithm to run in O(lg n) time
tmp = pow(n, e / 2)
return tmp * tmp
我将再次强调,tmp = pow(n, e / 2)
是降低时间复杂度的线。
算法不是将数乘以e乘以n,而是重用一些先前计算的结果。例如,2 ^ 8将被计算为2 ^ 4 * 2 ^ 4。这里2 ^ 4将只计算一次,并且将以这种方式跳过一半的迭代。同样适用于2 ^ 4等。
我试图以某种方式更直观地解释它,而没有深入研究这种优化背后的理论。如果你想更深入地理解它以及它在位级上的工作原理,那么这是一个很好的tutorial
答案 2 :(得分:1)
我将从一些更明显的代码开始:
double pow(double a, int n) {
int k = 0, m = 1, n2 = n;
double pow_k = 1.0, pow_m = a;
assert (n2 * m + k == n);
while (n2 != 0) {
if (n2 % 2 != 0) { k += m; pow_k *= pow_m; n2 -= 1; }
assert (n2 * m + k == n); assert (n2 % 2 == 0);
m = m * 2; pow_m = pow_m * pow_m; n2 /= 2;
assert (n2 * m + k == n);
}
return pow_k;
}
在循环中的每个点,pow_k = a ^ k和pow_m = a ^ m。 n2 * m + k == n始终为真。当n2 == n,m == 1,k == 0时,它最初为真。
在循环中的第一个if语句之前,n2是偶数,因此断言保持为真且n2保持偶数。或者n2是奇数。在那种情况下,n2减少1,使n2 * m减小m; k增加m,使n2 * m + k保持不变。并且n2是均匀的。
然后m加倍并且n2正好减半,因为n2是偶数,再次保持n2 * m + k不变。
由于在每次迭代中n2除以2,因此n2最终变为0,因此循环结束。具有n2 == 0的断言意味着0 * m + k == n或k == n,因此pow_k = a ^ k = a ^ n。因此返回的结果是^ n。
现在我们省略了k,m和断言,它们没有改变计算:
double pow(double a, int n) {
int n2 = n;
double pow_k = 1.0, pow_m = a;
while (n2 != 0) {
if (n2 % 2 != 0) { pow_k *= pow_m; n2 -= 1; }
m = m * 2; pow_m = pow_m * pow_m; n2 /= 2;
}
return pow_k;
}
当n2为奇数时,我们可以删除n2 - = 1,因为在除以2之后它不会产生差异。由于没有使用n,我们可以使用n而不是n2:
double pow(double a, int n) {
double pow_k = 1.0, pow_m = a;
while (n != 0) {
if (n % 2 != 0) pow_k *= pow_m;
pow_m = pow_m * pow_m; n /= 2;
}
return pow_k;
}
现在我们将pow_k更改为ret,将pow_m更改为a,并将n%2!= 0更改为n%2 == 1,我们将获得原始代码:
double pow(double a, int n) {
double ret = 1.0;
while (n != 0) {
if (n % 2 == 1) ret *= a;
a *= a; n /= 2;
}
return ret;
}