Karatsuba算法过多的递归

时间:2011-08-14 18:34:29

标签: python algorithm math recursion multiplication

我正在尝试用c ++实现Karatsuba乘法算法但是现在我只是想让它在python中运行。

这是我的代码:

def mult(x, y, b, m):
    if max(x, y) < b:
        return x * y

    bm = pow(b, m)
    x0 = x / bm
    x1 = x % bm
    y0 = y / bm
    y1 = y % bm

    z2 = mult(x1, y1, b, m)
    z0 = mult(x0, y0, b, m)
    z1 = mult(x1 + x0, y1 + y0, b, m) - z2 - z0

    return mult(z2, bm ** 2, b, m) + mult(z1, bm, b, m) + z0

我得不到的是:应该如何创建z2z1z0?使用mult函数递归正确吗?如果是这样,我在某处乱搞,因为递归并没有停止。

有人可以指出错误的位置吗?

5 个答案:

答案 0 :(得分:5)

  

注意:以下回复直接解决了OP的问题   过度递归,但它并不试图提供正确的   Karatsuba算法。其他回复的信息量更多   这方面。

试试这个版本:

def mult(x, y, b, m):
    bm = pow(b, m)

    if min(x, y) <= bm:
        return x * y

    # NOTE the following 4 lines
    x0 = x % bm
    x1 = x / bm
    y0 = y % bm
    y1 = y / bm

    z0 = mult(x0, y0, b, m)
    z2 = mult(x1, y1, b, m)
    z1 = mult(x1 + x0, y1 + y0, b, m) - z2 - z0

    retval = mult(mult(z2, bm, b, m) + z1, bm, b, m) + z0
    assert retval == x * y, "%d * %d == %d != %d" % (x, y, x * y, retval)
    return retval

您的版本最严重的问题是您的x0和x1以及y0和y1的计算都会被翻转。此外,如果x1y1为0,则算法的推导不成立,因为在这种情况下,因子分解步骤变为无效。因此,您必须通过确保x和y都大于b ** m来避免这种可能性。

编辑:修改了代码中的拼写错误;补充说明

EDIT2:

更清楚,直接评论原始版本:

def mult(x, y, b, m):
    # The termination condition will never be true when the recursive 
    # call is either
    #    mult(z2, bm ** 2, b, m)
    # or mult(z1, bm, b, m)
    #
    # Since every recursive call leads to one of the above, you have an
    # infinite recursion condition.
    if max(x, y) < b:
        return x * y

    bm = pow(b, m)

    # Even without the recursion problem, the next four lines are wrong
    x0 = x / bm  # RHS should be x % bm
    x1 = x % bm  # RHS should be x / bm
    y0 = y / bm  # RHS should be y % bm
    y1 = y % bm  # RHS should be y / bm

    z2 = mult(x1, y1, b, m)
    z0 = mult(x0, y0, b, m)
    z1 = mult(x1 + x0, y1 + y0, b, m) - z2 - z0

    return mult(z2, bm ** 2, b, m) + mult(z1, bm, b, m) + z0

答案 1 :(得分:4)

通常大数字存储为整数数组。每个整数代表一位数。这种方法允许将任意数乘以基数的幂乘以数组的简单左移。

这是我基于列表的实现(可能包含错误):

def normalize(l,b):
    over = 0
    for i,x in enumerate(l):
        over,l[i] = divmod(x+over,b)
    if over: l.append(over)
    return l
def sum_lists(x,y,b):
    l = min(len(x),len(y))
    res = map(operator.add,x[:l],y[:l])
    if len(x) > l: res.extend(x[l:])
    else: res.extend(y[l:])
    return normalize(res,b)
def sub_lists(x,y,b):
    res = map(operator.sub,x[:len(y)],y)
    res.extend(x[len(y):])
    return normalize(res,b)
def lshift(x,n):
    if len(x) > 1 or len(x) == 1 and x[0] != 0:
        return [0 for i in range(n)] + x
    else: return x
def mult_lists(x,y,b):
    if min(len(x),len(y)) == 0: return [0]
    m = max(len(x),len(y))
    if (m == 1): return normalize([x[0]*y[0]],b)
    else: m >>= 1
    x0,x1 = x[:m],x[m:]
    y0,y1 = y[:m],y[m:]
    z0 = mult_lists(x0,y0,b)
    z1 = mult_lists(x1,y1,b)
    z2 = mult_lists(sum_lists(x0,x1,b),sum_lists(y0,y1,b),b)
    t1 = lshift(sub_lists(z2,sum_lists(z1,z0,b),b),m)
    t2 = lshift(z1,m*2)
    return sum_lists(sum_lists(z0,t1,b),t2,b)

sum_listssub_lists返回非标准化结果 - 单个数字可能大于基值。 normalize函数解决了这个问题。

所有函数都希望以相反的顺序获取数字列表。例如,基数10中的12应写为[2,1]。让我们取一个9987654321的正方形。

» a = [1,2,3,4,5,6,7,8,9]
» res = mult_lists(a,a,10)
» res.reverse()
» res
[9, 7, 5, 4, 6, 1, 0, 5, 7, 7, 8, 9, 9, 7, 1, 0, 4, 1]

答案 2 :(得分:4)

Karatsuba乘法的目标是通过3次递归调用而不是4次递增调用来改进分而治之的乘法算法。因此,脚本中应包含对乘法的递归调用的唯一行是分配z0z1z2 的行。任何其他东西都会让你更加复杂。当你还没有定义乘法(以及更宽的取幂)时,你不能使用pow来计算 b m

为此,该算法至关重要地使用了它使用位置记法系统的事实。如果你在 b 中有一个数字的表示 x ,那么 x * b m 只是通过将该表示的数字m次移到左侧。对于任何位置符号系统,该移位操作基本上是“自由的”。这也意味着如果你想实现它,你必须重现这种位置表示法和“免费”转换。你选择在base b = 2 中计算并使用python的位运算符(或者如果你的测试平台有它们,则使用给定小数,十六进制,... base的位运算符),或者您决定为教育目的实现适用于任意 b 的东西,并使用字符串,数组或列表等方式重现此位置算术。

您已经有solution with lists了。我喜欢在python中使用字符串,因为int(s, base)将为您提供对应于字符串s的整数,该字符串被视为基数base中的数字表示:它使测试变得容易。 我发布了一个评论很多的基于字符串的实现作为要点here ,包括字符串到数字和数字到字符串的原语,以便进行衡量。

你可以通过提供带有基数的填充字符串及其(等于)长度作为mult的参数来测试它:

In [169]: mult("987654321","987654321",10,9)

Out[169]: '966551847789971041'

如果您不想弄清楚填充或计数字符串长度,填充函数可以为您完成:

In [170]: padding("987654321","2")

Out[170]: ('987654321', '000000002', 9)

当然,它适用于b>10

In [171]: mult('987654321', '000000002', 16, 9)

Out[171]: '130eca8642'

(选中wolfram alpha

答案 3 :(得分:1)

我认为这种技术背后的想法是使用递归算法计算z i 项,但结果不是那样统一的。由于您想要的最终结果是

z0 B^2m + z1 B^m + z2

假设您选择了合适的B值(比如2),您可以在不进行任何乘法的情况下计算B ^ m。例如,当使用B = 2时,您可以使用位移而不是乘法来计算B ^ m。这意味着最后一步可以在不进行任何乘法的情况下完成。

还有一件事 - 我注意到你为整个算法选择了一个固定的m值。通常,你可以通过使m总是一个值来实现这个算法,使得当它们用基数B写入时,B ^ m是x和y中数字的一半。如果你使用的是2的幂,那么这将完成通过选择m = ceil((log x)/ 2)。

希望这有帮助!

答案 4 :(得分:0)

在Python 2.7中:将此文件另存为Karatsuba.py

   def karatsuba(x,y):
        """Karatsuba multiplication algorithm.
        Return the product of two numbers in an efficient manner
        @author Shashank
        date: 23-09-2018

        Parameters
        ----------
        x : int
            First Number 
        y : int
            Second Number   

        Returns
        -------
        prod : int
               The product of two numbers 

        Examples
        --------
        >>> import Karatsuba.karatsuba
        >>> a = 1234567899876543211234567899876543211234567899876543211234567890
        >>> b = 9876543211234567899876543211234567899876543211234567899876543210
        >>> Karatsuba.karatsuba(a,b)
        12193263210333790590595945731931108068998628253528425547401310676055479323014784354458161844612101832860844366209419311263526900
        """
        if len(str(x)) == 1 or len(str(y)) == 1:
            return x*y
        else:
            n = max(len(str(x)), len(str(y)))
            m = n/2

            a = x/10**m
            b = x%10**m
            c = y/10**m
            d = y%10**m

            ac = karatsuba(a,c)                             #step 1
            bd = karatsuba(b,d)                             #step 2
            ad_plus_bc = karatsuba(a+b, c+d) - ac - bd      #step 3
            prod = ac*10**(2*m) + bd + ad_plus_bc*10**m     #step 4
            return prod