Python:最近邻居(或最接近匹配)过滤数据记录(元组列表)

时间:2012-03-19 10:50:59

标签: python

我正在尝试编写一个函数来过滤元组列表(模仿内存数据库),使用“最近邻居”或“最接近匹配”类型算法。

我想知道最好的(即大多数Pythonic)方式去做这件事。下面的示例代码有望说明我正在尝试做什么。

datarows = [(10,2.0,3.4,100),
            (11,2.0,5.4,120),
            (17,12.9,42,123)]

filter_record = (9,1.9,2.9,99) # record that we are seeking to retrieve from 'database' (or nearest match)
weights = (1,1,1,1) # weights to approportion to each field in the filter

def get_nearest_neighbour(data, criteria, weights):
    for each row in data:
        # calculate 'distance metric' (e.g. simple differencing) and multiply by relevant weight
    # determine the row which was either an exact match or was 'least dissimilar'
    # return the match (or nearest match)
    pass

if __name__ == '__main__':
    result = get_nearest_neighbour(datarow, filter_record, weights)
    print result

对于上面的代码段,输出应为:

(适用10,2.0,3.4,100)

因为它是传递给函数get_nearest_neighbour()的样本数据的“最近”。

我的问题是,实施 get_nearest_neighbour()的最佳方式是什么?为了简洁起见,假设我们只处理数值,并且我们使用的“距离度量”只是从当前行中输入数据的算术减法。

3 个答案:

答案 0 :(得分:4)

简单的开箱即用解决方案:

import math

def distance(row_a, row_b, weights):
    diffs = [math.fabs(a-b) for a,b in zip(row_a, row_b)]
    return sum([v*w for v,w in zip(diffs, weights)])

def get_nearest_neighbour(data, criteria, weights):
    def sort_func(row):
        return distance(row, criteria, weights)
    return min(data, key=sort_func)

如果您需要使用大型数据集,则应考虑切换到Numpy并使用Numpy的KDTree来查找最近的邻居。使用Numpy的优势在于它不仅使用更高级的算法,而且还实现了高度优化的LAPACK (Linear Algebra PACKage)

答案 1 :(得分:2)

关于naive-NN:

其他许多答案都提出了“天真的最近邻居”,这是一种O(N*d) - 每个查询算法(d是维度,在这种情况下似乎是常数,所以它是O(N) - 每个查询)。

虽然O(N) - 每个查询算法非常糟糕,但如果你的算法少于(例如)任何一个,那么你可以侥幸使用它:

  • 10个查询和100000个点
  • 100个查询和10000点
  • 1000个查询和1000个点
  • 10000个查询和100个点
  • 100000个查询和10个点

比天真的NN做得更好:

否则,您将需要使用以下列出的技术之一(尤其是最近邻数据结构):

特别是如果您打算多次运行您的程序。最有可能的库可用。如果您拥有#queries * #points的大量产品,否则不使用NN数据结构将花费太多时间。正如用户'dsign'在评论中指出的那样,你可以通过使用numpy库来挤出一个大的额外恒定速度因素。

但是,如果您可以使用简单易用的naive-NN,那么您应该使用它。

答案 2 :(得分:1)

在生成器上使用heapq.nlargest计算每条记录的距离*权重。

类似的东西:

heapq.nlargest(N, ((row, dist_function(row,criteria,weight)) for row in data), operator.itemgetter(1))