用于计算两组矢量之间的距离矩阵的简单函数,如
import numpy as np
from numba import jit
@jit(signature_or_function='(f8[:,:], f8[:,:])', nopython=True, cache=True, locals={'d': numba.float64[:]})
def foo(prototypes, features):
protonum = prototypes.shape[0]
featurenum = features.shape[0]
dismatrix = np.zeros(shape=(featurenum, protonum), dtype=np.double)
for i in range(featurenum):
feature = features[i,:]
tmp = (prototypes - feature)**2
d = np.sqrt(tmp.sum(axis=1)) # (nproto, 1)
dismatrix[i,:] = d
idx = d.argmin()
return dismatrix
输入参数'原型'是ndarray形状(质子,暗),'功能'是ndarray形状(featurenum,昏暗)。如果&nbspthon' nopython'设置为False,但如果&nbspthon' nopython'设置为True。错误消息是
numba.errors.TypingError: Failed at nopython (nopython frontend)
No conversion from float64 to array(float64, 1d, A) for 'd'
似乎numba推断出' d'错误的是' float64'而不是float64数组,即使由' locals = {' d':numba.float64 [:]}'明确指定。我有什么地方做错了吗?
Numba版本= 0.23.1,Python 3.5。