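Can I provide the types of some arguments to numba.njit but let it infer the rest?

Posted: 2019-12-23 20:44:08

Tags: python-2.7 numba

I'm using numba.njit. It can easily infer the types for H1(), but not for H2(), because there I pass a function F as one of the arguments.

Is there a way to tell numba.njit the type of F but let it infer the rest, so that I don't have to provide the whole signature?

When I run the following code: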

import numpy as np
import numba 


@numba.njit
def F1(s):
    return 1/s

@numba.njit
def H1(s, p):
    return F1(s)/(F1(s)+p['tau'])

@numba.njit
def H2(s, p, F):
    return F(s)/(F(s)+p['tau'])

def prepare_params(x=None):
    try:
        shape = x.shape
    except AttributeError:  # x is None, or otherwise has no .shape
        shape = ()
    f64 = np.dtype(np.float64)
    p = np.zeros(shape=shape, dtype=[('tau',f64),
                                     ('something',f64)])
    p['tau'] = 0.001
    p['something'] = 1
    return p

s = np.logspace(0,2,5)*1j
p = prepare_params(s)
print "H1=", H1(s,p)
print "H2=", H2(s,p,F1)

I get:

H1= [ 0.99999900-0.001j       0.99999000-0.00316225j  0.99990001-0.009999j
  0.99900100-0.03159119j  0.99009901-0.0990099j ]
H2=
---------------------------------------------------------------------------
TypingError                               Traceback (most recent call last)
<ipython-input-18-2a1359160fe9> in <module>()
     30 p = prepare_params(s)
     31 print "H1=", H1(s,p)
---> 32 print "H2=", H2(s,p,F1)

c:\app\python\anaconda\2\lib\site-packages\numba\dispatcher.pyc in _compile_for_args(self, *args, **kws)
    328                                 for i, err in failed_args))
    329                 e.patch_message(msg)
--> 330             raise e
    331 
    332     def inspect_llvm(self, signature=None):

TypingError: Caused By:
Traceback (most recent call last):
  File "c:\app\python\anaconda\2\lib\site-packages\numba\compiler.py", line 238, in run
    stage()
  File "c:\app\python\anaconda\2\lib\site-packages\numba\compiler.py", line 452, in stage_nopython_frontend
    self.locals)
  File "c:\app\python\anaconda\2\lib\site-packages\numba\compiler.py", line 865, in type_inference_stage
    infer.propagate()
  File "c:\app\python\anaconda\2\lib\site-packages\numba\typeinfer.py", line 844, in propagate
    raise errors[0]
TypingError: Internal error at <numba.typeinfer.ArgConstraint object at 0x00000000089DD320>:
--%<-----------------------------------------------------------------
Traceback (most recent call last):
  File "c:\app\python\anaconda\2\lib\site-packages\numba\typeinfer.py", line 137, in propagate
    constraint(typeinfer)
  File "c:\app\python\anaconda\2\lib\site-packages\numba\typeinfer.py", line 190, in __call__
    typeinfer.add_type(self.dst, ty, loc=self.loc)
  File "c:\app\python\anaconda\2\lib\contextlib.py", line 35, in __exit__
    self.gen.throw(type, value, traceback)
  File "c:\app\python\anaconda\2\lib\site-packages\numba\errors.py", line 265, in new_error_context
    six.reraise(type(newerr), newerr, sys.exc_info()[2])
  File "c:\app\python\anaconda\2\lib\site-packages\numba\errors.py", line 259, in new_error_context
    yield
  File "c:\app\python\anaconda\2\lib\site-packages\numba\typeinfer.py", line 189, in __call__
    assert ty.is_precise()
InternalError: 
[1] During: typing of argument at <ipython-input-18-2a1359160fe9> (15)
--%<-----------------------------------------------------------------

File "<ipython-input-18-2a1359160fe9>", line 15

Failed at nopython (nopython frontend)
Internal error at <numba.typeinfer.ArgConstraint object at 0x00000000089DD320>:
--%<-----------------------------------------------------------------
Traceback (most recent call last):
  File "c:\app\python\anaconda\2\lib\site-packages\numba\typeinfer.py", line 137, in propagate
    constraint(typeinfer)
  File "c:\app\python\anaconda\2\lib\site-packages\numba\typeinfer.py", line 190, in __call__
    typeinfer.add_type(self.dst, ty, loc=self.loc)
  File "c:\app\python\anaconda\2\lib\contextlib.py", line 35, in __exit__
    self.gen.throw(type, value, traceback)
  File "c:\app\python\anaconda\2\lib\site-packages\numba\errors.py", line 265, in new_error_context
    six.reraise(type(newerr), newerr, sys.exc_info()[2])
  File "c:\app\python\anaconda\2\lib\site-packages\numba\errors.py", line 259, in new_error_context
    yield
  File "c:\app\python\anaconda\2\lib\site-packages\numba\typeinfer.py", line 189, in __call__
    assert ty.is_precise()
InternalError: 
[1] During: typing of argument at <ipython-input-18-2a1359160fe9> (15)
--%<-----------------------------------------------------------------

File "<ipython-input-18-2a1359160fe9>", line 15

This error may have been caused by the following argument(s):
- argument 2: cannot determine Numba type of <class 'numba.targets.registry.CPUDispatcher'>
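The last line names the actual cause: this Numba version has no type for a CPUDispatcher (a jitted function) passed as an argument. A commonly used workaround, which is not part of the original thread and is assumed rather than verified on Numba 0.35, is to bind F when H is compiled instead of passing it at call time, for example with a factory closure. The sketch is written for Python 3; make_H and H_F1 are hypothetical names, tau is simplified to a plain float instead of the structured array p, and the njit fallback exists only so the sketch still runs without Numba installed.

```python
try:
    from numba import njit
except ImportError:  # fallback so the sketch still runs without Numba installed
    def njit(func):
        return func


@njit
def F1(s):
    return 1 / s


def make_H(F):
    """Hypothetical factory: returns a jitted H specialized on the jitted F.

    F is a closure variable, so Numba resolves it at compile time (like a
    global) instead of having to infer the type of a dispatcher argument.
    """
    @njit
    def H(s, tau):
        return F(s) / (F(s) + tau)
    return H


H_F1 = make_H(F1)
print(H_F1(1j, 0.001))
```

Each call to make_H produces a separate jitted H for the F it receives, so this trades one generic H2 for one specialization per F.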

1 answer:

Answer 0 (score: 0)

Never mind: when I wrote the question I was using Numba 0.35. After upgrading to 0.46 it works:

H1= [ 0.99999900-0.001j       0.99999000-0.00316225j  0.99990001-0.009999j
  0.99900100-0.03159119j  0.99009901-0.0990099j ]
H2= [ 0.99999900-0.001j       0.99999000-0.00316225j  0.99990001-0.009999j
  0.99900100-0.03159119j  0.99009901-0.0990099j ]
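For reference, here is a Python 3 sketch of the same call pattern, which should succeed on Numba 0.39 or later according to the FAQ entry quoted below. It assumes tau can be passed as a plain float instead of the structured array p, and the njit fallback is only there so the sketch still runs without Numba installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fallback so the sketch still runs without Numba installed
    def njit(func):
        return func


@njit
def F1(s):
    return 1 / s


@njit
def H2(s, tau, F):
    # F is a jitted function passed as an ordinary argument; on Numba 0.39+
    # the dispatcher object gets a type, and H2 is compiled for that F
    return F(s) / (F(s) + tau)


s = np.logspace(0, 2, 5) * 1j
print("H2 =", H2(s, 0.001, F1))
```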

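I found this entry in the FAQ:

Can I pass a function as an argument to a jitted function?

As of Numba 0.39, you can, so long as the function argument has also been JIT-compiled: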