为什么numba推断类型错误?

时间:2016-02-24 05:32:32

标签: python numba

用于计算两组矢量之间的距离矩阵的简单函数,如

import numpy as np
from numba import jit

@jit(signature_or_function='(f8[:,:], f8[:,:])',  nopython=True, cache=True, locals={'d': numba.float64[:]})
def foo(prototypes, features):
    protonum   = prototypes.shape[0]
    featurenum = features.shape[0]
    dismatrix = np.zeros(shape=(featurenum, protonum), dtype=np.double)
    for i in range(featurenum):
        feature = features[i,:]
        tmp = (prototypes - feature)**2
        d = np.sqrt(tmp.sum(axis=1))   # (nproto, 1)
        dismatrix[i,:] = d
        idx = d.argmin()
    return dismatrix

输入参数'原型'是ndarray形状(质子,暗),'功能'是ndarray形状(featurenum,昏暗)。如果&nbspthon' nopython'设置为False,但如果&nbspthon' nopython'设置为True。错误消息是

numba.errors.TypingError: Failed at nopython (nopython frontend)
No conversion from float64 to array(float64, 1d, A) for 'd'

似乎numba推断出' d'错误的是' float64'而不是float64数组,即使由' locals = {' d':numba.float64 [:]}'明确指定。我有什么地方做错了吗?

Numba版本= 0.23.1,Python 3.5。

0 个答案:

没有答案