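I'm working on a supervised k-means algorithm (some centroids are fixed and never move; the others behave normally), and I'm trying to use Numba to speed it up.
Since Numba's support for comprehensions is shaky at best and essentially nonexistent at worst, I've fallen back to plain for loops in the hope that Numba can speed those up.
My code looks like this: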
import numpy
from numpy import array
from numba import jit
from typing import List, Union  # used by the type aliases listed at the end of the post

# This function works as intended.
@jit(nopython = True)
def distsqr(p1: Point, p2: Point) -> Numeric:
    """
    Euclidean distance squared between 2 points in R^3.
    """
    return ((p1 - p2) ** 2).sum()

@jit(nopython = True)
def kpp_init(points: PointArray, k: int, fixed_centroids: PointArray) -> PointList:
    """
    The KMeans++ initialization method using some given fixed centroids.
    """
    mobile_centroids = []
    all_centroids = list(fixed_centroids)
    points_left = list(points)
    if not len(all_centroids):
        new_centroid = numpy.random.choice(array(list(range(len(points_left)))))
        all_centroids += [points_left[new_centroid]]
        mobile_centroids += [points_left[new_centroid]]
        del points_left[new_centroid]
    while len(all_centroids) < k:
        # For each point, distance to the nearest centroid, squared.
        # dxsqr = [min(distsqr(centroid, point) for centroid in all_centroids) for point in points_left]
        dxsqr = []
        for point in points_left:
            dist = []
            for centroid in all_centroids:
                dist += [distsqr(point, centroid)]
            dxsqr += [min(dist)]
        dxsqr_sum = numpy.sum(dxsqr)
        # Not working. Numba wants loops.
        # weights = [dxsqr_val / dxsqr_sum for dxsqr_val in dxsqr]
        weights = []
        for dxsqr_val in dxsqr:
            weights += [dxsqr_val / dxsqr_sum]
        new_centroid = numpy.random.choice(array(list(range(len(points_left)))), p = weights)
        all_centroids += [points_left[new_centroid]]
        mobile_centroids += [points_left[new_centroid]]
        del points_left[new_centroid]
        break
    return mobile_centroids
Now, whenever I try to test this with Numba, I get a not-so-horrible error:
  File "E:\Tools\Python36\lib\site-packages\numba\typeinfer.py", line 888, in check_var
    raise TypingError("Can't infer type of variable '%s': %s" % (var, tp))
numba.errors.TypingError: Failed at nopython (nopython frontend)
Can't infer type of variable 'dist': list(undefined)
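The error itself is easy enough to understand, but I don't know how to fix it.
So my question is: is there some mechanism for telling Numba the type of dist? It will only ever contain distances, so I'd guess the type would be list(float32). If there's no such thing, how can I get Numba to play nice and compile my function?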
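For what it's worth, one way to sidestep the untyped empty list (rather than annotate it) would be to not build dist at all: keep a running minimum in a scalar and write the results into a preallocated array. The sketch below is only an illustration under some assumptions: points and centroids are 2-D float arrays, the helper name nearest_dist_sqr is made up for this example, and the fallback decorator is there only so the snippet still runs as plain Python when Numba isn't installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback so the sketch still runs as plain Python without Numba.
    def njit(func):
        return func


@njit
def nearest_dist_sqr(points, centroids):
    """Squared distance from each point to its nearest centroid."""
    n_points = points.shape[0]
    out = np.empty(n_points, dtype=np.float64)
    for i in range(n_points):
        best = np.inf  # running minimum replaces the dist list and min(dist)
        for j in range(centroids.shape[0]):
            d = 0.0
            for c in range(points.shape[1]):
                diff = points[i, c] - centroids[j, c]
                d += diff * diff
            if d < best:
                best = d
        out[i] = best
    return out


# Hypothetical sample data, just to exercise the function.
points = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [3.0, 0.0, 0.0]])
centroids = np.array([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]])
dxsqr = nearest_dist_sqr(points, centroids)
```

Because every variable here is a scalar or an array with a known dtype, there is no empty container whose element type has to be inferred. It doesn't answer whether a list can be typed directly, but it would play the role of the dxsqr loop in kpp_init.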
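P.S. The other type annotations are basically all numpy.ndarray in one form or another. I doubt they have any effect, since I put them there for my own benefit, but just in case: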
Numeric = Union[numpy.array, numpy.dtype, int, float] # Numeric
NumericArray = numpy.ndarray # List[Numeric]
Point = NumericArray # Array[Numeric, Numeric, Numeric]
PointArray = NumericArray # List[Point]
PointList = List[Point]
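A separate thing that may bite once the typing error is out of the way is the numpy.random.choice(..., p = weights) call in kpp_init. As far as I know, Numba's nopython version of numpy.random.choice does not accept the p argument. A possible stand-in, sketched here with a made-up helper name (weighted_index), draws a uniform number and looks it up in the cumulative sum of the weights. It is shown as plain NumPy; I believe cumsum, random and searchsorted are all on Numba's supported list, but that is worth checking against the installed version.

```python
import numpy as np


def weighted_index(weights):
    """Draw one index with probability proportional to weights[i]."""
    cdf = np.cumsum(weights)
    # Scaling by the total means the weights need not be normalised first.
    u = np.random.random() * cdf[-1]
    idx = np.searchsorted(cdf, u, side="right")
    # Guard against floating-point edge cases at the upper end.
    return min(idx, len(weights) - 1)


# Hypothetical squared distances: only index 1 has non-zero weight.
dxsqr = np.array([0.0, 4.0, 0.0])
picks = [weighted_index(dxsqr) for _ in range(100)]
```

With this approach the separate weights list would not be needed, since the unnormalised squared distances can be passed in directly.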