如何使用numba正确编译此函数?

时间:2019-12-02 11:49:08

标签: python numpy numba

我正在尝试使用numba来编译一个相对简单的递归函数,该函数通过将每行计算为前几行的线性组合来生成矩阵,我将其表示为内积,并将向量vec截短为所需长度:

    @numba.jit("float64[:,:](float64[:], int32)")
    def r(vec, J):
        t = np.empty([len(vec), J], dtype=np.float64) # final matrix has shape len(vec) x J
        t[0] = vec[0]*np.arange(J) # first row
        for i in range(1, len(vec)):
            t[i] = np.sum(vec[:i, None]*t[:i], axis=0)
        return t

    >>> "NumbaWarning: Compilation is falling back to object mode WITH looplifting enabled 
        because Function "r" failed type inference due to: Invalid use of
        Function(<built-in function empty>) with argument(s) of type(s):
        (list(int64), dtype=class(float64))"

我不确定这是因为我没有正确使用jit中的类型特征,还是因为Numba不喜欢我编写函数的方式。

0 个答案:

没有答案