我正在尝试使用numba来编译一个相对简单的递归函数,该函数通过将每行计算为前几行的线性组合来生成矩阵,我将其表示为内积,并将向量vec
截短为所需长度:
@numba.jit("float64[:,:](float64[:], int32)")
def r(vec, J):
t = np.empty([len(vec), J], dtype=np.float64) # final matrix has shape len(vec) x J
t[0] = vec[0]*np.arange(J) # first row
for i in range(1, len(vec)):
t[i] = np.sum(vec[:i, None]*t[:i], axis=0)
return t
>>> "NumbaWarning: Compilation is falling back to object mode WITH looplifting enabled
because Function "r" failed type inference due to: Invalid use of
Function(<built-in function empty>) with argument(s) of type(s):
(list(int64), dtype=class(float64))"
我不确定这是因为我没有正确使用jit
中的类型特征,还是因为Numba不喜欢我编写函数的方式。