在numba的jit编译时迭代一个元组

时间:2018-09-27 12:15:10

标签: python numba

在编译时,numba jitted函数中是否有一种方法可以评估函数元组(或列表)中的每个函数?

请注意,这个问题是关于如何在编译时使用Python循环构建jit代码,而不是在运行时在元组上进行迭代,这我不知道。

>

下面是一个完全无效的示例,但其核心是以下内容有效:

@jit(nopython=True)
def do_stuff(func_tuple):
    results = []
    results.append(func_tuple[0]())
    results.append(func_tuple[1]())
    results.append(func_tuple[2]())
    results.append(func_tuple[3]())
    return results

但以下内容则不行:

@jit(nopython=True)
def do_stuff_2(func_tuple):
    results = []
    for i in range(4):
        results.append(func_tuple[i]())
    return results

错误消息如下,其含义非常清楚:在运行时不支持索引到这样的元组。

Invalid usage of getitem with parameters ((type(CPUDispatcher(<function f1 at 0x116968268>)), type(CPUDispatcher(<function f2 at 0x1169688c8>)), type(CPUDispatcher(<function f3 at 0x1169a1b70>)), type(CPUDispatcher(<function f4 at 0x1169a1f28>))), int64)
 * parameterized
[1] During: typing of intrinsic-call at numba_minimal_not_working_example_2.py (36)

File "numba_minimal_not_working_example_2.py", line 36:
def do_stuff_2(func_tuple):
    <source elided>
    for i in range(4):
        results.append(func_tuple[i]())
  ^

但是,我只需要在编译时进行索引-我基本上只想生成类似于do_stuff的函数,但是要根据元组中的元素数自动生成。

原则上,这可以在编译时发生,因为numba认为元组的长度是其类型的一部分。但是我还无法解决该怎么做。我尝试了各种涉及递归和/或@generated_jit装饰器的技巧,但是我还没有设法解决一些有用的问题。有没有办法做到这一点?

这是完整的示例:

from numba import jit

@jit(nopython=True)
def f1():
    return 1

@jit(nopython=True)
def f2():
    return 2

@jit(nopython=True)
def f3():
    return 3

@jit(nopython=True)
def f4():
    return 4

func_tuple = (f1, f2, f3, f4)

# this works:
@jit(nopython=True)
def do_stuff(func_tuple):
    results = []
    results.append(func_tuple[0]())
    results.append(func_tuple[1]())
    results.append(func_tuple[2]())
    results.append(func_tuple[3]())
    return results

# but this does not:
@jit(nopython=True)
def do_stuff_2(func_tuple):
    results = []
    for i in range(4):
        results.append(func_tuple[i]())
    return results

# this doesn't either (similar error to do_stuff_2).
@jit(nopython=True)
def do_stuff_3(func_tuple):
    results = [f() for f in func_tuple]
    return results


print(do_stuff(func_tuple)) # prints '[1, 2, 3, 4]'

print(do_stuff_2(func_tuple)) # gives the error above

#print(do_stuff_3(func_tuple)) # gives a similar error

1 个答案:

答案 0 :(得分:1)

这实际上是已知的limitation of Numbaget的回溯中也提到了这一点。

基本上,当您要求@jit您的函数时,Numba无法正确推断已编译代码的类型。

一种解决方法是在@jit(nopython=False)上使用do_stuff_2(),这样便可以通过使用Python对象系统来处理此类代码。 相反,您将无法@jit使用do_stuff_3()函数,甚至不能使用nopython=False,因为numba不支持理解(至少在0.39.0版之前) )。