将一列中的所有行与同一列中的所有其他行进行比较(特殊查询)

时间:2019-05-16 02:10:29

标签: apache-spark apache-spark-sql

在此查询中,我得到了一个包含5d欧式点列的数据框(存储为双精度数组)。我需要找到所有可用的平均距离。也就是说,对于每个点a,我都计算出数据框中到另一个点b的距离,并找到这些距离的平均值。请注意,我不需要任何数学方法或简化此问题。数据框有两列,unique_id和vector。

我可以执行查询,但是只能通过以下方式针对1点进行查询。 UDF距离计算存储的数组(即包装的数组)与给定数组之间的距离。但是,很明显,这种方法仅适用于一点。另外,我尝试将数据集传递给静态函数。但是每次执行此操作时,我都会得到一个“无效树:空”,即对象进入函数后立即变为空...最后,我想到了制作UDAF,但我意识到这并不是一个UDAF。适当的集合函数。任何帮助,将不胜感激!

(注意:此代码是用Java编写的,但与其他语言没什么不同)

        long equal = 2;
        WrappedArray<Double> num = (WrappedArray<Double> spo.select("vectors")
       .filter(col("unique_id").equalTo(equal)).first().get(0);
        List<Double> frameList =  scala.collection.JavaConverters.seqAsJavaList(num);

        double[] array_answer = frameList.stream().mapToDouble(Double::doubleValue).toArray();
        UserDefinedFunction compare = udf(
                (WrappedArray<Double> array)  -> cosine_distance(array, array_answer),  DataTypes.DoubleType
        );
        double answer = (double) spo.select("vectors").filter(col("unique_id").notEqual(equal))
            .withColumn("calc", compare.apply(col("vectors")))
            .select(avg("calc")).first().get(0);
        System.out.println(answer);

1 个答案:

答案 0 :(得分:0)

可以使用crossJoin完成。这是Scala中的1d(伪)代码:

val df = Seq(("a", 1), ("b", 2), ("c", 3)).toDF("unique_id", "vector")

df.select($"unique_id" as "id0", $"vector" as "vector0")
  .crossJoin(df.select($"unique_id" as "id1", $"vector" as "vector1"))
  .filter($"id0" =!= $"id1")
  .groupBy($"id0" as "unique_id")
  .agg(avg(
    abs($"vector0" - $"vector1") /*  use actual distance here */ ) as "mean_distance")
  .show()
+---------+-------------+
|unique_id|mean_distance|
+---------+-------------+
|        c|          1.5|
|        b|          1.0|
|        a|          1.5|
+---------+-------------+