在numba中将数组作为函数参数

时间:2019-02-25 13:04:35

标签: python numba

以下简单示例失败,并显示错误:

独立模块:

from numba.pycc import CC

cc = CC('foo')

@cc.export('product','float64(float64[:], float64[:])')
def product(a, b):
    prod = 0
    for i in range(a.size):
        prod += a[i] * b[i]
    return prod

if __name__ == "__main__":
    cc.compile()

测试程序:

import numpy as np
import foo

x = np.array([2,3,1,0])
y = np.array([2,3,1,0])

print(foo.product(x,y))

失败,并显示错误消息:

Traceback (most recent call last):
  File "\temp\test.py", line 7, in <module>
    print(foo.product(x,y))
SystemError: exception RuntimeError<class 'BytesWarning'> not a BaseException subclass

在Windows上使用的numba版本是0.42.0和Python 3.7.2。 有提示吗?

1 个答案:

答案 0 :(得分:0)

所以,我终于让您的代码可以工作了:

from numba.pycc import CC

cc = CC('foo')
cc.verbose = True
@cc.export('producti','int64(int64[:], int64[:])')  #<--- Your data type was wrong
def product(a, b):
    prod = 0
    for i in range(a.size):
        y = a[i] * b[i]
        prod += y
    return prod

if __name__ == "__main__":
    cc.compile()

测试上述功能的代码:

import numpy as np
import foo

x = np.array([2, 3, 1, 0])
y = np.array([2, 3, 1, 0])

print(foo.producti(x, y))   # Output : 14

一些要注意的地方

  • 创建xy数组的方式中,默认将dtype设置为int64,因此当您将其类型转换为float64时错误地转换。
  

print(x.dtype)

     

输出:dtype('int64')

  • 因此,只需将类型固定为int64就可以了(或者您可以根据需要使用i8作为速记)。

  • 使用运行代码链接到Google Colab笔记本Notebook Link

参考