Python中使用numba的简单求和函数无法计算

时间:2015-10-07 13:46:10

标签: python numba

我正在尝试学习Python和Numba,我无法弄清楚为什么以下代码无法在IPython / Jupyter中计算:

from numba import *

sample_array = np.arange(10000.0)

@jit('float64(float64, float64)')
def sum(x, y):
    return x + y

sum(sample_array, sample_array)
  

TypeError Traceback(最近一次调用最后一次)    in()   ----> 1 sum(sample_array,sample_array)

     

C:\ Users *** \ AppData \ Local \ Continuum \ Anaconda \ lib \ site-packages \ numba \ dispatcher.pyc in _explain_matching_error(self,* args,** kws)       201 msg =(“参数类型没有匹配的定义%s”       202%','。join(map(str,args)))    - > 203引发TypeError(msg)       204       205 def repr (自我):

     

TypeError:参数类型数组(float64,1d,C),数组(float64,1d,C)没有匹配的定义

2 个答案:

答案 0 :(得分:4)

您正在传入数组,但您的jit签名需要标量浮点数。请尝试以下方法:

@jit('float64[:](float64[:], float64[:])')
def sum(x, y):
    return x + y

我的建议是看看你是否可以不指定类型而只是使用裸@jit装饰器,它将在运行时进行类型推断,你可以更灵活地处理输入。例如:

@jit(nopython=True)
def sum(x, y):
    return x + y

In [13]: sum(1,2)
Out[13]: 3

In [14]: sum(np.arange(5),np.arange(5))
Out[14]: array([0, 2, 4, 6, 8])

我的经验是,添加类型很少会带来任何性能上的好处。

答案 1 :(得分:0)

就我而言,这是因为我传递的是二维数组(矩阵),但它期望的是一维数组(向量)