两个点之间的Apache Spark距离使用squaredDistance

时间:2014-10-25 07:39:38

标签: scala apache-spark dbscan rdd

我有一个RDD的向量集合,其中每个向量代表一个带xy坐标的点。例如,文件如下:

1.1 1.2
6.1 4.8
0.1 0.1
9.0 9.0
9.1 9.1
0.4 2.1

我正在读它:

  def parseVector(line: String): Vector[Double] = {
    DenseVector(line.split(' ')).map(_.toDouble)
  }

  val lines = sc.textFile(inputFile)
  val points = lines.map(parseVector).cache()

另外,我有一个epsilon:

  val eps = 2.0

对于每个点我想找到它们在epsilon距离内的邻居。我这样做:

points.foreach(point =>
  // squaredDistance(point, ?) what should I write here?
)

如何循环所有点并为每个点找到它的邻居?可能使用map函数?

4 个答案:

答案 0 :(得分:2)

您可以执行以下操作:

val distanceBetweenPoints = points.cartesian(points)
    .filter{case (x,y) => (x!=y)} // remove the (x,x) diagonal
    .map{case (x,y) => ((x,y),distance(x,y))}
val pointsWithinEps = distanceBetweenPoints.filter{case ((x,y),distance) => distance <= eps)}

如果您之后不关心点之间的距离,也可以将滤波器中的距离计算结合起来。

答案 1 :(得分:1)

您可以使用SparkAI library并执行以下操作:

import org.aizook.scala.clustering.Spark_DBSCAN.DBSCAN val cluster:Dbscan = new Dbscan(3,5,data) cluster.predict((2000,(48.3,33.1)))

`val data: RDD(Long,(Double, Double)
eps = 3
minPts = 5`

答案 2 :(得分:1)

即使这个答案还没有被接受,我在这里作为一个通知,在github repo中提出的基本相同的已被接受的解决方案由于笛卡尔运算而具有可扩展性。 O(n^2)作为复杂性和庞大的数据集,这绝对是一个问题。

还有另一种解决方案,即通过Spark实现DBSCAN算法的另一种方法,可以在https://github.com/alitouka/spark_dbscan找到。该解决方案提出了一种不同的方法,将RDD数据集划分为&#34; box&#34;。以这种方式,近点可以仅是所考虑的点的相同框中的那些点以及远离连续分区的边界的小于epsil的那些。通过这种方式,复杂性降至O(m^2)mn/k,其中k为分区数。此外,还会执行其他优化(如果您需要更多详细信息,可以阅读代码,联系作者或向我询问)。

之前的实施有一些限制:仅支持欧几里德和曼哈顿测量,只能成功处理维度很少的数据集。为了解决这个问题,我创建了这个分支,旨在消除所有这些问题:https://github.com/speedymrk9/spark_dbscan/tree/distance-measure-independent。现在,似乎工作正常并且所有问题都得到了解决,尽管我正在进行测试,以确保它在发出拉取请求之前没有任何缺陷。

答案 3 :(得分:0)

@Bob那是因为(48.3,33.1)不适合群集,应归类为噪音。 我提交了SparkAI library的更新,它应该在预测适合噪声时返回-1

import org.aizook.scala.clustering.Spark_DBSCAN.Dbscan
val eps = 2
val minPts = 2
val data = sc.textFile("data.txt").map(_.split(" ")).map(p => (p(0).trim.toDouble, p(1).trim.toDouble)).zipWithUniqueId().map(x => (x._2,x._1)).cache;
val cluster:Dbscan = new Dbscan(eps,minPts,data)
cluster.predict((data.count+1,(9.0,10.0)))  // Should return 1 for cluster 1
cluster.predict((data.count+2,(2.0,2.0)))   // Should return 0 for cluster 0
cluster.predict((data.count+3,(15.0,23.0))) // Should return -1 for noise

包含您提交的数据样本的data.txt:

1.1 1.2
6.1 4.8
0.1 0.1
9.0 9.0
9.1 9.1
0.4 2.1