如何用Spark查找最近的10亿条记录?

时间:2016-05-03 18:46:20

标签: apache-spark pyspark spark-dataframe nearest-neighbor euclidean-distance

鉴于包含以下信息的10亿条记录:

    ID  x1  x2  x3  ... x100
    1   0.1  0.12  1.3  ... -2.00
    2   -1   1.2    2   ... 3
    ...

对于上面的每个ID,我想找到前10个最接近的ID,基于它们的向量的欧几里德距离(x1,x2,...,x100)。

计算此数据的最佳方式是什么?

3 个答案:

答案 0 :(得分:7)

碰巧,我有一个解决方案,包括将sklearn与Spark结合起来:https://adventuresindatascience.wordpress.com/2016/04/02/integrating-spark-with-scikit-learn-visualizing-eigenvectors-and-fun/

它的要点是:

  • 集中使用sklearn的k-NN fit()方法
  • 然后分布式地使用sklearn的k-NN kneighbors()方法

答案 1 :(得分:5)

对所有记录执行所有记录的暴力比较是一场失败的战斗。我的建议是采用现成的k-Nearest Neighbor算法实现,例如scikit-learn提供的算法,然后广播得到的索引和距离数组,然后再进一步。

这种情况下的步骤是:

1-像Bryce建议的那样对这些特征进行矢量化,让你的矢量化方法返回一个浮动列表(或numpy数组),其中包含与你的特征一样多的元素

2-适合你的scikit-learn nn到你的数据:

nbrs = NearestNeighbors(n_neighbors=10, algorithm='auto').fit(vectorized_data)

3-在您的矢量化数据上运行经过训练的算法(训练和查询数据在您的情况下是相同的)

distances, indices = nbrs.kneighbors(qpa)

步骤2和3将在您的pyspark节点上运行,在这种情况下不可并行化。您需要在此节点上有足够的内存。在我的情况下,有150万条记录和4个功能,需要一两秒钟。

在我们为火花获得NN的良好实施之前,我想我们必须坚持这些解决方法。如果您想尝试新的内容,请转到http://spark-packages.org/package/saurfang/spark-knn

答案 2 :(得分:1)

你还没有提供很多细节,但我对这个问题采取的一般方法是:

  1. 将记录转换为类似LabeledPoint的数据结构,其中(ID,x1..x100)为标签和功能
  2. 映射每条记录并将该记录与所有其他记录(此处有大量优化空间)进行比较
  3. 创建一些截止逻辑,以便一旦开始比较ID = 5且ID = 1,就会中断计算,因为您已将ID = 1与ID = 5进行比较
  4. 一些缩小步骤以获取{id_pair: [1,5], distance: 123}
  5. 等数据结构
  6. 找到每条记录的10个最近邻居的另一个地图步骤
  7. 您已经确定了pyspark,我通常使用scala进行此类工作,但每个步骤的一些伪代码可能如下所示:

    # 1. vectorize the features
    def vectorize_raw_data(record)
        arr_of_features = record[1..99]
        LabeledPoint( record[0] , arr_of_features)
    
    # 2,3 + 4 map over each record for comparison
    broadcast_var = [] 
    def calc_distance(record, comparison)
        # here you want to keep a broadcast variable with a list or dictionary of
        # already compared IDs and break if the key pair already exists
        # then, calc the euclidean distance by mapping over the features of
        # the record and subtracting the values then squaring the result, keeping 
        # a running sum of those squares and square rooting that sum
        return {"id_pair" : [1,5], "distance" : 123}    
    
    for record in allRecords:
      for comparison in allRecords:
        broadcast_var.append( calc_distance(record, comparison) )
    
    # 5. map for 10 closest neighbors
    
    def closest_neighbors(record, n=10)
         broadcast_var.filter(x => x.id_pair.include?(record.id) ).takeOrdered(n, distance)
    

    伪代码很糟糕,但我认为它传达了意图。当你将所有记录与所有其他记录进行比较时,这里会有很多改组和排序。恕我直言,你想将密钥对/距离存储在一个中心位置(就像一个广告变量,虽然这是危险的,但它会更新),以减少你执行的欧几里德总距离计算。