加速函数,将函数作为numba的参数

时间:2017-08-31 08:22:06

标签: python python-3.x numba

我正在尝试使用numba来加速一个将另一个函数作为参数的函数。最小的例子如下:

import numba as nb

def f(x):
    return x*x

@nb.jit(nopython=True)
def call_func(func,x):
    return func(x)

if __name__ == '__main__':
    print(call_func(f,5))
然而,这并不起作用,因为显然numba不知道如何处理该函数参数。回溯很长:

Traceback (most recent call last):
  File "numba_function.py", line 15, in <module>
    print(call_func(f,5))
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/dispatcher.py", line 330, in _compile_for_args
    raise e
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/dispatcher.py", line 307, in _compile_for_args
    return self.compile(tuple(argtypes))
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/dispatcher.py", line 579, in compile
    cres = self._compiler.compile(args, return_type)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/dispatcher.py", line 80, in compile
    flags=flags, locals=self.locals)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 740, in compile_extra
    return pipeline.compile_extra(func)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 360, in compile_extra
    return self._compile_bytecode()
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 699, in _compile_bytecode
    return self._compile_core()
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 686, in _compile_core
    res = pm.run(self.status)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 246, in run
    raise patched_exception
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 238, in run
    stage()
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 452, in stage_nopython_frontend
    self.locals)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 841, in type_inference_stage
    infer.propagate()
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/typeinfer.py", line 773, in propagate
    raise errors[0]
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/typeinfer.py", line 129, in propagate
    constraint(typeinfer)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/typeinfer.py", line 380, in __call__
    self.resolve(typeinfer, typevars, fnty)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/typeinfer.py", line 402, in resolve
    raise TypingError(msg, loc=self.loc)
numba.errors.TypingError: Failed at nopython (nopython frontend)
Invalid usage of pyobject with parameters (int64)
No type info available for pyobject as a callable.
File "numba_function.py", line 10
[1] During: resolving callee type: pyobject
[2] During: typing of call at numba_function.py (10)

This error may have been caused by the following argument(s):
- argument 0: cannot determine Numba type of <class 'function'>

有没有办法解决这个问题?

2 个答案:

答案 0 :(得分:4)

这取决于您传递给func的{​​{1}}是否可以call_func模式进行编译。

如果它不能以nopython模式编译,那么它是不可能的,因为numba不支持nopython函数内的python调用(这就是它的原因)叫做nopython)。

但是如果它可以在nopython模式下编译,你可以使用闭包:

nopython

这种方法有一些明显的缺点,因为每次调用import numba as nb def f(x): return x*x def call_func(func, x): func = nb.njit(func) # compile func in nopython mode! @nb.njit def inner(x): return func(x) return inner(x) if __name__ == '__main__': print(call_func(f,5)) 时都需要编译funcinner。这意味着只有通过编译函数的加速比编译成本大,它才有可行。如果多次使用相同的函数调用call_func,则可以减轻这种开销:

call_func

只是一般性说明:我不会创建带有函数参数的numba函数。如果你不能对函数进行硬编码,那么numba不能生成非常快的函数,如果你还包括闭包的编译成本,那么它大多不值得。

答案 1 :(得分:2)

根据错误消息的建议,Numba无法处理table1类型的值。您可以查看the documentation Numba可以使用的类型。原因是Numba一般不能在function模式下优化(jit-compile)任意函数,它们基本上被认为是一个黑盒子(实际上,传递的函数甚至可以是原生函数!)。

通常的方法是让Numba改为优化被调用函数。如果您无法将装饰器添加到该函数中(例如,因为它不是源代码的一部分),您仍然可以手动使用它,如:

noptyhon

显然,如果Numba无法编译import numba as nb def f(x): return x*x if __name__ == '__main__': f_opt = nb.jit(nopython=True)(f) print(f_opt(5)) ,它仍会失败,但在这种情况下,你无论如何都无法做到。