Nopython模式下Numba的递归函数错误

时间:2019-04-08 16:33:02

标签: python python-3.x recursion tuples numba

我想使用nopython模式在Numba中运行递归函数。到目前为止,我只遇到错误。这是一个非常简单的代码,用户给出一个少于五个元素的元组,然后该函数创建另一个元组,并在元组中添加一个新值(在本例中为数字3)。重复此过程,直到最后一个元组的长度为5。由于某种原因,这不起作用,不知道为什么。

@njit
def tup(a):
    if len(a) == 5:
        return a
    else:
        b = a + (3,)
        b = tup(b)
        return b

例如,如果为a = (0,1),我希望最终结果为元组(0,1,3,3,3)

编辑:我正在使用Numba 0.41.0,并且我得到的错误是内核快死了,'内核似乎已经死了。它将自动重新启动。'

2 个答案:

答案 0 :(得分:1)

根据当前版本中的this list of proposals

  

numba中的递归支持目前仅限于使用   该函数的显式类型注释。此限制来自   无法确定递归调用的返回类型。

因此,请尝试:

from numba import jit

@jit()
def tup(a:tuple) -> tuple:
    if len(a) == 5:
        return a

    return tup(a + (3,))

print(tup((0, 1)))

看看这对您是否更好。

答案 1 :(得分:1)

您不应该这样做的原因有很多:

  • 通常这是一种方法,在纯Python中可能比在数字装饰函数中更快。
  • 迭代将更简单并且可能更快,但是请注意,即使在numba中,连接元组通常也是First Name Last Name Offer Status John Smith Declined Jane Anderson Accepted 操作。因此,该函数的整体性能为O(n)。这可以通过使用支持O(n**2)追加的数据结构或支持预分配大小的数据结构来改善。或者只是不使用“循环”或“递归”方法。
  • 您是否尝试过省略O(1)装饰器并传入包含6个元素的元组会发生什么情况? (提示:它将达到递归限制,因为它永远不会满足递归的结束条件。)

在编写0.43.1时,Numba仅在参数类型在两次递归之间不变时才支持简单递归。在您的情况下,类型确实发生了变化,您传入了njit,但是递归调用尝试传入了另一种类型的tuple(int64 x 2)。奇怪的是,它遇到了我计算机上的tuple(int64 x 3)-似乎是numba中的错误。

我的建议是使用此代码(无编号,无递归):

StackOverflow

还会返回预期结果:

def tup(a):
    if len(a) < 5:
        a += (3, ) * (5 - len(a))
    return a