我需要为通用指标构建一个相异矩阵。由于我需要算法快速运行,因此我在numba
模式下使用了nopython
0.35。
这是我的代码
import numpy as np
from numba import jit
from jellyfish import levenshtein_distance
def _dissimilarity_matrix(metric):
@jit(nopython=True)
def dm(data):
n = data.shape[0]
diss = np.zeros((n, n))
for i in range(n):
for j in range(i+1):
dist = metric(data[i], data[j])
diss[i, j] = dist
diss[j, i] = dist
return diss
return dm
@jit(nopython=True)
def euclidean_distance(vec1, vec2):
return np.sqrt(((vec1 - vec2)**2).sum())
test1 = np.random.randn(10, 2)
dissimilarity_matrix1 = _dissimilarity_matrix(euclidean_distance)
diss1 = dissimilarity_matrix1(test1)
test2 = np.array(["this", "is", "a", "test"])
dissimilarity_matrix2 = _dissimilarity_matrix(levenshtein_distance)
diss2 = dissimilarity_matrix2(test2)
但输出是:
numba.errors.TypingError: Failed at nopython (nopython frontend)
Untyped global name 'metric': cannot determine Numba type of <class 'builtin_function_or_method'>
File "test.py", line 12
请注意,函数euclidean_distance
由我定义并具有装饰器@jit(nopython=True)
,而函数levenshtein_distance
来自外部模块(不是由我编写)。有没有办法明确告诉numba
传入的函数的签名(即metric
中的_dissimilarity_matrix
)?
我真的需要函数_dissimilarity_matrix
以nopython
模式运行并接受任意函数作为输入。
答案 0 :(得分:2)
当metric
为euclidean_distance
时,您的代码适合我,因为这是一个nopython
jitted numba函数的函数。但是你不能传递任意函数。为了使某些内容在nopython
模式下工作,numba必须支持每个可调用函数(请参阅http://numba.pydata.org/numba-doc/latest/reference/pysupported.html和http://numba.pydata.org/numba-doc/latest/reference/numpysupported.html)或用户定义为numba nopython
函数。没有解决这个限制。