无法使用带有* args参数的函数的jit编译函数

时间:2019-10-27 04:17:06

标签: python numba

我正在尝试编译一个包含numpy数组和元组的函数 使用numba的* arg形式的参数。

import numba as nb
import numpy as np

@nb.njit(cache=True)
def myfunc(t, *p):
    val = 0
    for j in range(0, len(p), 2):
        val += p[j]*np.exp(-p[j+1]*t)
    return val

T = np.arange(12)
pars = (1.0, 2.0, 3.0, 4.0)
mfunc = myfunc(T, *pars)

但是我得到这个结果

In [1]: run numba_test.py                                                                                                                                                                  
---------------------------------------------------------------------------                                                                                                                
TypingError                               Traceback (most recent call last)                                                                                                                
~/Programs/my-python/numba_test.py in <module>                                                                                                                                             
     12                                                                                                                                                                                    
     13 T = np.arange(12)                                                                                                                                                                  
---> 14 mfunc = myfunc(T, 1.0, 2.0, 3.0, 4.0)                                                                                                                                              

...
...                                                                                                                                                                                   
TypingError: Failed in nopython mode pipeline (step: nopython frontend)                                                                                                                    
Invalid use of Function(<built-in function iadd>) with argument(s) of type(s): (Literal[int](0), array(float64, 1d, C))                                                                    
Known signatures:                                                                                                                                                                          
 * (int64, int64) -> int64                                                                                                                                                                 
 * (int64, uint64) -> int64                                                                                                                                                                
 * (uint64, int64) -> int64                                                                                                                                                                
 * (uint64, uint64) -> uint64                                                                                                                                                              
 * (float32, float32) -> float32                                                                                                                                                           
 * (float64, float64) -> float64                                                                                                                                                           
 * (complex64, complex64) -> complex64                                                                                                                                                     
 * (complex128, complex128) -> complex128                                                                                                                                                  
 * parameterized                                                                                                                                                                           
In definition 0:                                                                                                                                                                           
    All templates rejected with literals.                                                                                                                                                  
...
...                                                                                                                                                                         
    All templates rejected without literals.                                                                                                                                               
This error is usually caused by passing an argument of a type that is unsupported by the named function.                                                                                   
[1] During: typing of intrinsic-call at /home/cshugert/Programs/my-python/numba_test.py (9)                                                                                                

File "numba_test.py", line 9:                                                                                                                                                              
def myfunc(t, *p):                                                                                                                                                                         
    <source elided>                                                                                                                                                                        
    for j in range(0, len(p), 2):                                                                                                                                                          
        val += p[j]*np.exp(-p[j+1]*t)                                                                                                                                                      
        ^                                                                                                                                                                                  

Numba确实支持使用元组,因此我认为可能会有 我可以在jit编译器中添加一些签名。但是,我不确定 确切地放在那里。 numba编译器是否可能是这种情况 无法处理带有* args作为参数的函数?我有什么办法可以使我的功能正常工作?

1 个答案:

答案 0 :(得分:1)

让我们再次看到错误消息

TypingError: Failed in nopython mode pipeline (step: nopython frontend)                                                                                                                    
Invalid use of Function(<built-in function iadd>) with argument(s)
 of type(s): (Literal[int](0), array(float64, 1d, C))                                                                    
Known signatures:                                                                                                                                                                          
 * (int64, int64) -> int64                                                                                                                                                                 
 * (int64, uint64) -> int64                                                                                                                                                                
 * (uint64, int64) -> int64                                                                                                                                                                
 * (uint64, uint64) -> uint64                                                                                                                                                              
 * (float32, float32) -> float32                                                                                                                                                           
 * (float64, float64) -> float64                                                                                                                                                           
 * (complex64, complex64) -> complex64                                                                                                                                                     
 * (complex128, complex128) -> complex128                                                                                                                                                  
 * parameterized  

该错误是<built-in function iadd>的错误,它是+。如果您查看该错误,则不是由于传递了*args,而是由于以下语句:

val += p[j]*np.exp(-p[j+1]*t)

基本上在提到的+的所有兼容类型中,它都不支持在integer上添加array(有关更多信息,请参见错误消息和已知签名)。

您可以通过使用val(请参见文档here)将np.zeros设置为零数组来解决此问题。

import numba as nb
import numpy as np

@nb.njit
def myfunc(t, *p):
    val = np.zeros(12) #<------------ Set it as an array of zeros
    for j in range(0, len(p), 2):
        val += p[j]*np.exp(-p[j+1]*t)
    return val

T = np.arange(12)
pars = (1.0, 2.0, 3.0, 4.0)
mfunc_val = myfunc(T, *pars)

您可以在this Google Colab notebook中查看代码。