我正在尝试在Python上实现Karatsuba乘法。 输入是长度为2的两个整数。它们的长度相同。
def mult(x,y):
if int(x) < 10 and int(y) <10:
return int(x)*int(y)
x_length = len(str(x))//2
y_length = len(str(y))//2
a = str(x)[:x_length]
b = str(x)[x_length:]
c = str(y)[:y_length]
d = str(y)[y_length:]
n = len(a) + len(b)
m = n//2
return 10**n* mult(a,c) + 10**m*(mult(a+b, c+d)-mult(a,c)-mult(b,d)) + mult(b,d)
跑步
mult(1234,5678)
这给出了以下错误:
if int(x) < 10 and int(y) <10:
RecursionError: maximum recursion depth exceeded while calling a Python object
如果我愿意
def mult(x,y):
if int(x) < 10 and int(y) <10:
return int(x)*int(y)
x_length = len(str(x))//2
y_length = len(str(y))//2
a = str(x)[:x_length]
b = str(x)[x_length:]
c = str(y)[:y_length]
d = str(y)[y_length:]
n = len(a) + len(b)
m = n//2
return 10**n* mult(a,c) + 10**m*(mult(a,d)+mult(b,c)) + mult(b,d)
因此,我在最后一行(即mult(a,c), mult(a,d), mult(b,c), mult(b,d)
)中进行了4次递归,而不是上面的3次(即mult(a,c), mult(a+b, c+d), mult(b,d)
)中的递归。
然后证明还可以。
为什么会这样?而且只有3个递归怎么办?
答案 0 :(得分:1)
a, b, c, d
是字符串。字符串加法是串联。 "1" + "2"
是"12"
。因此,传递给mult(a+b, c+d)
的不是您打算传递的。
TL; DR。
首先,递归应该很快终止。让我们看看为什么不这样。在print x, y
的开头添加mult
:
def mult(x, y):
print x, y
....
,然后将输出重定向到文件中。结果令人惊讶:
1234 5678
12 56
1 5
12 56
1 5
12 56
1 5
12 56
1 5
....
难怪堆栈溢出。问题是,为什么我们要重复12 56
的情况?让我们添加更多工具,以找出执行该操作的递归调用:
def mult(x,y,k=-1):
....
print a, b, c, d
ac = mult(a, c, 0)
bd = mult(b, d, 2)
return 10**n* ac + 10**m*(mult(a+b, c+d, 1) - ac - bd) + bd
结果是
-1 : 1234 5678
12 34 56 78
0 : 12 56
1 2 5 6
0 : 1 5
2 : 2 6
1 : 12 56
1 2 5 6
0 : 1 5
2 : 2 6
1 : 12 56
1 2 5 6
0 : 1 5
2 : 2 6
1 : 12 56
您可以看到标记为1
的递归调用总是得到12 56
。它是计算mult(a + b, c + d)
的调用。那好吧。它们a, b, c, d
都是字符串。 "1" + "2"
是"12"
。不完全是您的意思。
因此,请确定:参数是整数还是字符串,并对其进行相应处理。
答案 1 :(得分:0)
请注意,在您的第一个代码段中,您不是三次调用函数,而是调用了5次:
return 10**n* mult(a,c) + 10**m*(mult(a+b, c+d)-mult(a,c)-mult(b,d)) + mult(b,d)
对于其余的代码,我还不能说清楚,但是快速浏览一下Karatsuba上的Wikipedia条目,您可以通过增加所使用的基数来减少递归深度(即从10减少到100或1000 )。您可以使用sys.setrecursionlimit
来更改递归深度,但是python堆栈框架可能会变得很大,因此请避免这样做,因为这样做可能很危险。