Numba减慢了我的计划,而不是加快速度

时间:2018-02-26 12:29:47

标签: python-3.x numba

我遇到了一段代码,它完成了我希望它比原始代码做得快得多的工作。然而,与我的原始代码不同,这个代码被numba jit函数减慢而不是加速。有没有人知道为什么会这样? 这是没有numba的代码:

def sum_factors(n):  
    result = []
    for i in range(1, int(n**0.5) + 1):
        if n % i == 0:
            result.extend([i, n//i])
    return sum(set(result)-set([n]))

def amicable_pair(number):
    result = []
    for x in range(1,number+1):
        y = sum_factors(x)
        if sum_factors(y) == x and x != y:
            result.append(tuple(sorted((x,y))))
    return set(result)
print(amicable_pair(100000))

这是带有numba函数的代码:

from numba import jit
@jit
def sum_factors(n):  
    result = []
    for i in range(1, int(n**0.5) + 1):
        if n % i == 0:
            result.extend([i, n//i])
    return sum(set(result)-set([n]))
@jit
def amicable_pair(number):
    result = []
    for x in range(1,number+1):
        y = sum_factors(x)
        if sum_factors(y) == x and x != y:
            result.append(tuple(sorted((x,y))))
    return set(result)
print(amicable_pair(100000))

第一个代码在jupyter笔记本中运行需要1.7秒,第二个代码在jupyter笔记本中需要6.5秒。

1 个答案:

答案 0 :(得分:1)

你必须采用你的代码进行jit编译:

@numba.njit
def sum_factors(n):  
    result = 1
    for i in range(2, int(n**0.5) + 1):
        if n % i == 0:
            result += i + n//i
    return result

def amicable_pair(number):
    result = []
    for x in range(1,number+1):
        y = sum_factors(x)
        if sum_factors(y) == x and x != y:
            result.append(tuple(sorted((x,y))))
    return set(result)
print(amicable_pair(100000))