有没有办法使这个Python kNN功能更有效?

时间:2014-10-16 08:32:25

标签: python numpy machine-learning distance knn

在遇到MATLAB麻烦后,我决定尝试使用Python:

我写了一个函数,当样本属于我自己的类时,使用我自己的距离函数来计算kNN:

def closestK(sample, otherSamples, distFunc, k):
"Returns the closest k samples to sample based on distFunc"
    n = len(otherSamples)
    d = [distFunc(sample, otherSamples[i]) for i in range(0,n)]
    idx  = sorted(range(0,len(d)), key=lambda k: d[k])
    return idx[1:(k+1)]

def kNN(samples, distFunc, k):
    return [[closestK(samples[i], samples, distFunc, k)] for i in range(len(samples))]

这是距离函数:

@staticmethod    
def distanceRepr(c1, c2):
    r1 = c1.repr
    r2 = c2.repr
    # because cdist needs 2D array
    if r1.ndim == 1:
        r1 = np.vstack([r1,r1])
    if r2.ndim == 1:
        r2 = np.vstack([r2,r2])

    return scipy.spatial.distance.cdist(r1, r2, 'euclidean').min()

但与#34;普通"相比,它的效果仍然非常慢。 kNN功能,即使使用" brute"算法。我做错了吗?

更新

我添加了类的构造函数。属性 repr 包含一组向量(从1到无),距离计算为两组repr之间的最小欧氏距离。

class myCluster:
    def __init__(self, index = -1, P = np.array([])):
        if index ==-1 :
            self.repr = np.array([])
            self.IDs = np.array([])
            self.n = 0
            self.center = np.array([])
        else:
            self.repr = np.array(P)
            self.IDs = np.array(index)
            self.n = 1
            self.center = np.array(P)

和其他相关代码(X是一个矩阵,其行是样本,列是变量):

level = [myCluster(i, X[i,:]) for i in range(0,n)]
kNN(level, myCluster.distanceRepr, 3)

更新2

我做了一些测量,大部分时间都是

d = [distFunc(sample, otherSamples[i]) for i in range(0,n)]

所以distFunc有一些东西。当我改变它以返回

np.linalg.norm(c1.repr-c2.repr)

即。 "正常"矢量计算,排序,运行时间保持不变。所以问题在于调用这个函数。使用类是否有意义将运行时间改变60倍?

2 个答案:

答案 0 :(得分:2)

你只是遇到了Python的缓慢(或者更确切地说,我猜应该说CPython解释器)。来自wikipedia

  

NumPy的目标是Python的CPython参考实现,它是一个非优化的bytecode编译器/解释器。为此版本的Python编写的数学算法通常比compiled等效运行速度慢得多。 NumPy试图通过提供在数组上高效运行的多维数组和函数以及运算符来解决这个问题。因此,任何可以主要表示为对数组和矩阵的操作的算法的运行速度几乎与等效的C代码一样快。

来自Scipy FAQ:

  

Python的列表是高效的通用容器。它们支持(相当)有效的插入,删除,追加和连接,Python的列表推导使它们易于构造和操作。但是,它们有一定的局限性:它们不支持元素化加法和乘法等“向量化”操作,并且它们可以包含不同类型的对象这一事实意味着Python必须存储每个元素的类型信息,并且必须执行类型调度代码在对每个元素进行操作时。这也意味着很少有列表操作可以通过有效的C循环来执行 - 每次迭代都需要进行类型检查和其他Python API簿记。

注意这不仅仅涉及Python;有关更多背景,请参阅SO上的thisthis question

由于动态类型系统和解释器的开销,如果无法利用各种编译的C和Fortran库(例如Numpy),Python对高性能数字运算的用处就会大打折扣。 )。此外,还有像Numba和PyPy这样的JIT编译器,它们试图让Python代码更接近静态类型编译代码的速度。

底线:相对于您正在卸载到快速C代码的工作,您在普通Python中做了很多工作。我想你需要采用更像“面向数组”的编码风格而不是面向对象来实现Numpy的良好性能(MATLAB在这方面是一个非常相似的故事)。另一方面,如果你使用更有效的算法(参见Ara的答案),那么Python的缓慢可能不是一个问题。

答案 1 :(得分:0)

以下是我能想到的要点:

  • 每次调用nearestK时,你计算一个样本和每个样本之间的距离,所以你计算每个样本之间的距离两次(一次距离(a,b)然后距离(b,a)),这可以通过一劳永逸地计算
  • 你重新计算r(可能涉及一个成本高的vstack)2 *(n - 1)次,其中n是len(样本),你也可以一劳永逸地计算它(并将它存储为myCluster的一个属性?)
  • 您在完整列表中计算排序,而您只需要top-k(无需在第k个元素之后排序)
  • 计算你的集合点之间的最小距离,你创建一个包含每个距离的矩阵然后采取它的最小值:你当然可以做得更好

我的建议是使用 insert 方法实现一个top-k类,只有当你比当前第k个元素更好(并删除它)并修改myCluster才能包含它河那么您的代码可能看起来像

kNN = {i : TopK() for i in xrange(len(samples))}
for i, sample1 in enumerate(samples):
    for j, sample2 in enumerate(samples[:i]):
        dist = distanceRepr(sample1, sample2)
        kNN[i].insert(j, -dist)
        kNN[j].insert(i, -dist)
return kNN

这是一个可能的实现ok TopK:

import heapq

class TopK:
    def __init__(self, k):
        self.k = k
        self.content = []

    def insert (self, key, score):
        if len(self.content) < self.k:
            heapq.heappush(self.content, (score, key))
        else:
            heapq.heappushpop(self.content, (score, key))

    def get_keys(self):
        return [elem[1] for elem in self.content]

对于distanceRepr,您可以使用以下内容:

import scipy.spatial

def distanceRepr(set0 ,set1):
    if len(set0) < len(set1):
        min_set = set0
        max_set = set1
    else:
        min_set = set1
        max_set = set0
    if len(min_set) == 0:
        raise Exception("Empty set")

    min_dist = scipy.inf
    tree = scipy.spatial.cKDTree(max_set)

    for point in min_set:
        distance, _ = tree.query(point, 1, 0., 2, min_dist)
        if min_dist > distance:
            min_dist = min(min_dist, distance)

    return min_dist

对于中型和大型条目,它会比您当前的方法更快(比如说sample1和2的大小> 5k),它的内存使用量也会小得多,允许它使用大样本(其中{ {1}}只是内存不足。