用NumPY改善KNN效率的线性搜索

时间:2016-09-12 19:39:13

标签: python numpy machine-learning

我正在尝试从训练集中的每个点计算测试集中每个点的距离:

这就是我现在的循环:

 for x in testingSet
    for y in trainingSet
        print numpy.linalg.norm(x-y)

其中testingSet和trainingSet是numpy数组,其中两个集合中的每一行都包含一个示例的要素数据。

然而,由于我的数据集较大(测试设置为3000,训练集为~10,000),因此运行速度非常慢,耗时超过10分钟。这与我的方法有关,还是我错误地使用numPY?

1 个答案:

答案 0 :(得分:3)

这是因为你天真地遍历你的数据,并且循环在python中很慢。相反,使用sklearn pairwise distance functions,甚至更好 - 使用sklearn efficient nearest neighbour搜索(如BallTree或KDTree)。如果您不想使用sklearn,还有module in scipy。最后你可以做“矩阵技巧”来计算这个,因为

|| x - y ||^2 = <x-y, x-y> = <x,x> + <y,y> - 2<x,y>

你可以做(​​假设你的数据是矩阵形式,给定为X和Y):

X2 = (X**2).sum(axis=1).reshape((-1, 1))
Y2 = (Y**2).sum(axis=1).reshape((1, -1))
distances = np.sqrt(X2 + Y2 - 2*X.dot(Y.T))