Numba jit和延迟类型

时间:2019-08-24 17:14:35

标签: python numba

我正在传递numba作为我的函数的签名

@numba.jit(numba.types.UniTuple(numba.float64[:, :], 2)(
    numba.float64[:, :], numba.float64[:, :], numba.float64[:, :], 
earth_model_type))

其中earth_model_type被定义为

earth_model_type = numba.deferred_type()
earth_model_type.define(em.EarthModel.class_type.instance_type)

它可以很好地编译,但是当我尝试调用该函数时,我得到了

  

*** TypeError:参数类型数组(float64、2d,F),数组(float64、2d,C),数组(float64、2d,F)没有匹配的定义,   instance.jitclass.EarthModel#7fd9c48dd668

在我看来,具有不匹配定义的参数类型与上面的类型几乎相同。另一方面,如果我不仅仅通过使用@numba.jit(nopython=True)来指定签名,它就可以正常工作,并且numba编译的函数的签名是

ipdb> numbed_cowell_propagator_propagate.signatures                   
  

[(array(float64,2d,F),array(float64,2d,C),array(float64,2d,F),   instance.jitclass.EarthModel#7f81bbc0e780)]

编辑

如果我使用FAQ中的方式强制执行C阶数组,我仍然会收到错误消息

  

TypeError:参数类型数组(float64,   2d,C),array(float64,2d,C),array(float64,2d,C),   instance.jitclass.EarthModel#7f6edd8d57b8

我非常确定问题与延迟类型有关,因为如果我不传递jit类,而是传递该类(4个numba.float64)中需要的所有属性,则可以正常工作。

指定签名时我在做什么错了?

干杯。

1 个答案:

答案 0 :(得分:0)

在不完全了解完整代码如何工作的情况下,我不确定为什么需要使用延迟类型。通常,它用于包含相同类型实例变量(例如链表或其他节点树)的jit类,因此需要推迟到编译器处理类本身之后再参见(请参见{{3} })以下最小示例起作用(如果我使用延迟类型,则可以重现您的错误):

import numpy as np
import numba as nb

spec = [('x', nb.float64)]

@nb.jitclass(spec)
class EarthModel:
    def __init__(self, x):
        self.x = x

earth_model_type = EarthModel.class_type.instance_type

@nb.jit(nb.float64(nb.float64[:, :], nb.float64[:, :], nb.float64[:, :], earth_model_type))
def test(x, y, z, em):
    return em.x

然后运行它:

em = EarthModel(9.9)
x = np.random.normal(size=(3,3))
y = np.random.normal(size=(3,3))
z = np.random.normal(size=(3,3))

res = test(x, y, z, em)
print(res)  # 9.9