Type-hinting arbitrary variables for Numba

Time: 2018-03-18 18:44:16

Tags: python numpy type-inference jit numba

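I'm working on a supervised KMeans algorithm (some centroids are fixed and never move, while the others behave normally), and I'm trying to use Numba to speed it up.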

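Since Numba's support for comprehensions is shaky at best and essentially nonexistent at worst, I reverted to plain for loops in the hope that Numba would speed them up.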

My code looks like this:

import numpy
from numpy import array
from numba import jit

# Point, PointArray, PointList and Numeric are the type aliases listed in the P.S. below;
# they must be defined before these functions.

# This function works as intended.
@jit(nopython = True)
def distsqr(p1: Point, p2: Point) -> Numeric:
    """
    Euclidean distance squared between 2 points in R^3.
    """
    return ((p1 - p2) ** 2).sum()


@jit(nopython = True)
def kpp_init(points: PointArray, k: int, fixed_centroids: PointArray) -> PointList:
    """
    The KMeans++ initialization method using some given fixed centroids.
    """
    mobile_centroids = []
    all_centroids    = list(fixed_centroids)
    points_left      = list(points)

    # No fixed centroids: pick the first centroid uniformly at random.
    if not len(all_centroids):
        new_centroid = numpy.random.choice(array(list(range(len(points_left)))))

        all_centroids    += [points_left[new_centroid]]
        mobile_centroids += [points_left[new_centroid]]

        del points_left[new_centroid]

    while len(all_centroids) < k:

        # For each point, distance to the nearest centroid, squared.
        # dxsqr = [min(distsqr(centroid, point) for centroid in all_centroids) for point in points_left]
        dxsqr = []
        for point in points_left:
            dist = []
            for centroid in all_centroids:
                dist += [distsqr(point, centroid)]
            dxsqr += [min(dist)]

        dxsqr_sum = numpy.sum(dxsqr)

        # Not working. Numba wants loops.
        # weights = [dxsqr_val / dxsqr_sum for dxsqr_val in dxsqr]
        weights = []
        for dxsqr_val in dxsqr:
            weights += [dxsqr_val / dxsqr_sum]

        # Pick the next centroid with probability proportional to dxsqr.
        new_centroid = numpy.random.choice(array(list(range(len(points_left)))), p = weights)

        all_centroids    += [points_left[new_centroid]]
        mobile_centroids += [points_left[new_centroid]]

        del points_left[new_centroid]

    return mobile_centroids
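Separately from the `dist` error, the `numpy.random.choice(..., p = weights)` line may be a second obstacle: as far as I know, Numba's nopython mode does not support the `p` argument of `numpy.random.choice`. A minimal sketch of one workaround, assuming the weights are a 1-D float array, is inverse-CDF sampling with `numpy.cumsum` and `numpy.searchsorted`. The helper name `weighted_index` is hypothetical, and the `try/except` fallback decorator is only there so the sketch also runs where Numba is not installed.

```python
import numpy as np

try:
    from numba import njit  # equivalent to jit(nopython=True)
except ImportError:
    # Hypothetical fallback: run as plain Python when Numba is absent.
    def njit(func):
        return func


@njit
def weighted_index(weights):
    # Hypothetical helper: return index i with probability weights[i] / weights.sum().
    # The weights need not be normalized, so dividing by dxsqr_sum becomes optional.
    cdf = np.cumsum(weights)          # running totals of the weights
    r = np.random.random() * cdf[-1]  # uniform draw in [0, total)
    # side="right" never lands on an entry whose weight is zero.
    return np.searchsorted(cdf, r, side="right")
```

Inside the loop this would stand in for the `numpy.random.choice` call, e.g. `new_centroid = weighted_index(dxsqr_array)`, where `dxsqr_array` is assumed to be a float64 array of the squared distances.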

Now, whenever I try to test this with Numba, I get a not-so-scary error:

File "E:\Tools\Python36\lib\site-packages\numba\typeinfer.py", line 888, in check_var
raise TypingError("Can't infer type of variable '%s': %s" % (var, tp))

numba.errors.TypingError: Failed at nopython (nopython frontend)
Can't infer type of variable 'dist': list(undefined)

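It's a fairly easy error to understand, but I don't know how to fix it. So my question is: is there some mechanism for telling Numba the type of `dist`? It will only ever contain distances, so I imagine the type would be list(float32). If there is no such mechanism, how can I get Numba to play nicely and compile my function?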
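One way around the inference problem, sketched under the assumption that the points and centroids are kept as 2-D float arrays (one point per row) instead of Python lists, is to avoid the empty list altogether: keep a running minimum in a scalar and write each result into a preallocated array whose dtype Numba can see. Note that NumPy's default float dtype is float64, so the distances would be float64 rather than float32. The name `min_dists_sqr` is hypothetical, and the fallback decorator only exists so the sketch runs without Numba.

```python
import numpy as np

try:
    from numba import njit  # equivalent to jit(nopython=True)
except ImportError:
    # Hypothetical fallback: run as plain Python when Numba is absent.
    def njit(func):
        return func


@njit
def distsqr(p1, p2):
    # Euclidean distance squared between two points.
    return ((p1 - p2) ** 2).sum()


@njit
def min_dists_sqr(points, centroids):
    # For each row of `points`, squared distance to its nearest row of `centroids`.
    # The preallocated float64 array replaces `dxsqr = []`, and the scalar `best`
    # replaces `dist = []`, so Numba never has to guess a list element type.
    n = points.shape[0]
    out = np.empty(n, dtype=np.float64)
    for i in range(n):
        best = np.inf
        for j in range(centroids.shape[0]):
            d = distsqr(points[i], centroids[j])
            if d < best:
                best = d
        out[i] = best
    return out
```

If you would rather keep lists, seeding `dist` with its first element (for example `dist = [distsqr(point, all_centroids[0])]`) and then calling `dist.append(...)` for the remaining centroids should also give Numba a concrete element type.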

P.S. The other type annotations are basically all numpy.ndarray in one form or another. I doubt they have any effect (I put them there for myself), but just in case:

import numpy
from typing import List, Union

Numeric      = Union[numpy.ndarray, numpy.dtype, int, float]    # Numeric
NumericArray = numpy.ndarray                                    # List[Numeric]
Point        = NumericArray                                     # Array[Numeric, Numeric, Numeric]
PointArray   = NumericArray                                     # List[Point]
PointList    = List[Point]
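On the P.S.: as far as I know, Numba does not read Python type annotations; it infers types from the actual arguments at call time. If you want to pin argument and return types, the decorator instead accepts an explicit signature string, which also triggers eager compilation. A minimal sketch, assuming the points are 1-D float64 arrays; note that a signature only types the arguments and the return value, so it would not by itself give the local `dist = []` a type. The no-op fallback decorator is hypothetical and only there so the sketch runs without Numba.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Hypothetical no-op stand-in that accepts both @njit and @njit("signature").
    def njit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]):
            return args[0]
        return lambda func: func


# Signature: returns float64, takes two 1-D float64 arrays (any memory layout).
@njit("float64(float64[:], float64[:])")
def distsqr(p1, p2):
    # Euclidean distance squared between two points.
    return ((p1 - p2) ** 2).sum()
```

With the signature in place, calling `distsqr` with, say, integer arrays should fail with a type error instead of silently compiling a second specialization.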

0 answers:

No answers