我遇到了一段代码,它完成了我希望它比原始代码做得快得多的工作。然而,与我的原始代码不同,这个代码被numba jit函数减慢而不是加速。有没有人知道为什么会这样? 这是没有numba的代码:
def sum_factors(n):
result = []
for i in range(1, int(n**0.5) + 1):
if n % i == 0:
result.extend([i, n//i])
return sum(set(result)-set([n]))
def amicable_pair(number):
result = []
for x in range(1,number+1):
y = sum_factors(x)
if sum_factors(y) == x and x != y:
result.append(tuple(sorted((x,y))))
return set(result)
print(amicable_pair(100000))
这是带有numba函数的代码:
from numba import jit
@jit
def sum_factors(n):
result = []
for i in range(1, int(n**0.5) + 1):
if n % i == 0:
result.extend([i, n//i])
return sum(set(result)-set([n]))
@jit
def amicable_pair(number):
result = []
for x in range(1,number+1):
y = sum_factors(x)
if sum_factors(y) == x and x != y:
result.append(tuple(sorted((x,y))))
return set(result)
print(amicable_pair(100000))
第一个代码在jupyter笔记本中运行需要1.7秒,第二个代码在jupyter笔记本中需要6.5秒。
答案 0 :(得分:1)
你必须采用你的代码进行jit编译:
@numba.njit
def sum_factors(n):
result = 1
for i in range(2, int(n**0.5) + 1):
if n % i == 0:
result += i + n//i
return result
def amicable_pair(number):
result = []
for x in range(1,number+1):
y = sum_factors(x)
if sum_factors(y) == x and x != y:
result.append(tuple(sorted((x,y))))
return set(result)
print(amicable_pair(100000))