为什么pow(a,d,n)比** d%n快得多?

时间:2013-01-03 06:00:16

标签: python performance pypy

我试图实现一个Miller-Rabin primality test,并且很困惑为什么中等数字(~7位数)需要这么长时间(> 20秒)。我最终发现以下代码行是问题的根源:

x = a**d % n

(其中adn都相似,但不相等,中等数字,**是取幂运算符,%是模运算符)

然后我尝试用以下内容替换它:

x = pow(a, d, n)

相比之下它几乎是瞬间的。

对于上下文,这是原始函数:

from random import randint

def primalityTest(n, k):
    if n < 2:
        return False
    if n % 2 == 0:
        return False
    s = 0
    d = n - 1
    while d % 2 == 0:
        s += 1
        d >>= 1
    for i in range(k):
        rand = randint(2, n - 2)
        x = rand**d % n         # offending line
        if x == 1 or x == n - 1:
            continue
        for r in range(s):
            toReturn = True
            x = pow(x, 2, n)
            if x == 1:
                return False
            if x == n - 1:
                toReturn = False
                break
        if toReturn:
            return False
    return True

print(primalityTest(2700643,1))

示例定时计算:

from timeit import timeit

a = 2505626
d = 1520321
n = 2700643

def testA():
    print(a**d % n)

def testB():
    print(pow(a, d, n))

print("time: %(time)fs" % {"time":timeit("testA()", setup="from __main__ import testA", number=1)})
print("time: %(time)fs" % {"time":timeit("testB()", setup="from __main__ import testB", number=1)})

输出(使用PyPy 1.9.0运行):

2642565
time: 23.785543s
2642565
time: 0.000030s

输出(使用Python 3.3.0运行,2.7.2返回非常相似的时间):

2642565
time: 14.426975s
2642565
time: 0.000021s

还有一个相关问题,为什么这个计算在运行Python 2或3时的速度几乎是PyPy的两倍,而PyPy通常是much faster

4 个答案:

答案 0 :(得分:160)

请参阅modular exponentiation上的维基百科文章。基本上,当你执行a**d % n时,实际上你必须计算a**d,这可能非常大。但有一些方法可以计算a**d % n,而无需计算a**d本身,这就是pow的作用。 **运算符无法执行此操作,因为它无法“看到将来”知道您将立即采用模数。

答案 1 :(得分:37)

BrenBarn回答了你的主要问题。对你而言:

  

为什么使用Python 2或3运行时的速度几乎是PyPy的两倍,而PyPy通常要快得多?

如果您阅读PyPy的performance page,这正是PyPy不擅长的事实 - 事实上,这是他们给出的第一个例子:

  

错误的例子包括使用大长度进行计算 - 这是由不可优化的支持代码执行的。

理论上,将一个巨大的取幂,然后将mod转换为模幂运算(至少在第一次传递之后)是JIT可能做出的转换......但不是PyPy的JIT。

作为旁注,如果您需要使用大整数进行计算,您可能需要查看第三方模块,如gmpy,这有时可能比CPython的本机实现快得多,在某些情况下除了主流用途,还有许多额外的功能,否则你不得不自己写,但代价是不方便。

答案 2 :(得分:11)

有进行模幂运算的快捷方式:例如,您可以从a**(2i) mod ni找到每1log(d)并将它们相乘(mod n })你需要的中间结果。像3参数pow()这样的专用模幂运算函数可以利用这些技巧,因为它知道你正在进行模运算。鉴于裸表达式a**d % n,Python解析器无法识别它,因此它将执行完整计算(这将花费更长时间)。

答案 3 :(得分:3)

计算x = a**d % n的方式是将a提升为d幂,然后使用n对其进行模数化。首先,如果a很大,则会创建一个巨大的数字,然后将其截断。但是,x = pow(a, d, n)最有可能被优化,以便仅跟踪最后n个数字,这些都是计算乘法模数所需的全部数据。