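I'm using numba.njit. It can easily infer the types for H1(), but it cannot infer the types for H2(), because there I pass a function F as one of the arguments.
Is there a way to tell numba.njit the type of F while letting it infer the rest, so that I don't have to supply a full signature?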
When I run the following code...
import numpy as np
import numba

@numba.njit
def F1(s):
    return 1/s

@numba.njit
def H1(s, p):
    return F1(s)/(F1(s)+p['tau'])

@numba.njit
def H2(s, p, F):
    return F(s)/(F(s)+p['tau'])

def prepare_params(x=None):
    try:
        shape = x.shape
    except AttributeError:  # x is None or has no .shape: fall back to a scalar record
        shape = ()
    f64 = np.dtype(np.float64)
    p = np.zeros(shape=shape, dtype=[('tau', f64),
                                     ('something', f64)])
    p['tau'] = 0.001
    p['something'] = 1
    return p

s = np.logspace(0, 2, 5)*1j
p = prepare_params(s)
print("H1=", H1(s, p))
print("H2=", H2(s, p, F1))
I get:
H1= [ 0.99999900-0.001j 0.99999000-0.00316225j 0.99990001-0.009999j
0.99900100-0.03159119j 0.99009901-0.0990099j ]
H2=
---------------------------------------------------------------------------
TypingError Traceback (most recent call last)
<ipython-input-18-2a1359160fe9> in <module>()
30 p = prepare_params(s)
31 print "H1=", H1(s,p)
---> 32 print "H2=", H2(s,p,F1)
c:\app\python\anaconda\2\lib\site-packages\numba\dispatcher.pyc in _compile_for_args(self, *args, **kws)
328 for i, err in failed_args))
329 e.patch_message(msg)
--> 330 raise e
331
332 def inspect_llvm(self, signature=None):
TypingError: Caused By:
Traceback (most recent call last):
File "c:\app\python\anaconda\2\lib\site-packages\numba\compiler.py", line 238, in run
stage()
File "c:\app\python\anaconda\2\lib\site-packages\numba\compiler.py", line 452, in stage_nopython_frontend
self.locals)
File "c:\app\python\anaconda\2\lib\site-packages\numba\compiler.py", line 865, in type_inference_stage
infer.propagate()
File "c:\app\python\anaconda\2\lib\site-packages\numba\typeinfer.py", line 844, in propagate
raise errors[0]
TypingError: Internal error at <numba.typeinfer.ArgConstraint object at 0x00000000089DD320>:
--%<-----------------------------------------------------------------
Traceback (most recent call last):
File "c:\app\python\anaconda\2\lib\site-packages\numba\typeinfer.py", line 137, in propagate
constraint(typeinfer)
File "c:\app\python\anaconda\2\lib\site-packages\numba\typeinfer.py", line 190, in __call__
typeinfer.add_type(self.dst, ty, loc=self.loc)
File "c:\app\python\anaconda\2\lib\contextlib.py", line 35, in __exit__
self.gen.throw(type, value, traceback)
File "c:\app\python\anaconda\2\lib\site-packages\numba\errors.py", line 265, in new_error_context
six.reraise(type(newerr), newerr, sys.exc_info()[2])
File "c:\app\python\anaconda\2\lib\site-packages\numba\errors.py", line 259, in new_error_context
yield
File "c:\app\python\anaconda\2\lib\site-packages\numba\typeinfer.py", line 189, in __call__
assert ty.is_precise()
InternalError:
[1] During: typing of argument at <ipython-input-18-2a1359160fe9> (15)
--%<-----------------------------------------------------------------
File "<ipython-input-18-2a1359160fe9>", line 15
Failed at nopython (nopython frontend)
This error may have been caused by the following argument(s):
- argument 2: cannot determine Numba type of <class 'numba.targets.registry.CPUDispatcher'>
Answer:
Never mind. When I wrote the question I was using Numba 0.35; I have since upgraded to 0.46, and it works:
H1= [ 0.99999900-0.001j 0.99999000-0.00316225j 0.99990001-0.009999j
0.99900100-0.03159119j 0.99009901-0.0990099j ]
H2= [ 0.99999900-0.001j 0.99999000-0.00316225j 0.99990001-0.009999j
0.99900100-0.03159119j 0.99009901-0.0990099j ]
Can I pass a function as an argument to a jitted function?
As of Numba 0.39, you can, as long as the function argument has also been JIT-compiled.