How do I retrieve a subset of rows from a DataFrame based on a condition?

Time: 2018-05-08 08:53:39

Tags: scala apache-spark apache-spark-sql

I have two DataFrames:

edges =
   srcId    dstId    timestamp
   1        4        1346564657
   1        2        1345769687
   2        4        1345769687
   4        1        1345769687
vertices =
   id   name   s_type
   1    abc    A
   2    def    B
   3    rtf    C
   4    wrr    D

I want to get the subset of vertices whose id is mentioned in neither the srcId nor the dstId column of edges.

This is the expected output:

sub_vertices =
   id   name   s_type
   3    rtf    C

How can I do that?

val sub_vertices = vertices
  .join(edges, col("id") =!= col("srcId") && col("id") =!= col("dstId"), "left")
  .na.fill(0)
  .drop("srcId","dstId", "timestamp")
  .dropDuplicates()

This is my current code, but it doesn't produce the correct result.

2 Answers:

Answer 0 (score: 2):

You're almost there; just a couple of things need to change. A join on =!= matches each vertex against every edge whose endpoints merely differ from it, so it excludes nothing. Join on equality instead, and keep only the vertices that found no match:

val sub_vertices = vertices
  .join(edges, col("id") === col("srcId") || col("id") === col("dstId"), "left")
  .filter($"srcId".isNull && $"dstId".isNull) // keep only the unmatched vertices
  .drop("srcId", "dstId", "timestamp")

Output:

+---+----+------+
|id |name|s_type|
+---+----+------+
|3  |rtf |C     |
+---+----+------+
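
For what it's worth, on Spark 2.0+ the same anti-join can be written in one step with the built-in left_anti join type, which keeps only the left-side rows with no match and never carries the edges columns along (a minimal sketch of the same logic):

val sub_vertices = vertices
  .join(edges, col("id") === col("srcId") || col("id") === col("dstId"), "left_anti")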

Answer 1 (score: 2):

You can collect the srcIds and dstIds into a set, broadcast it, and then use the broadcast set to filter the vertices DataFrame:

import org.apache.spark.sql.functions._

// collect all srcIds and dstIds from the edges dataframe into one set and broadcast it
val srcIdList = sc.broadcast(
  edges.select(collect_set("srcId").as("collectSrc"), collect_set("dstId").as("collectDst"))
    .rdd
    .map(row => (row.getAs[Seq[Int]](0) ++ row.getAs[Seq[Int]](1)).toSet)
    .collect()(0)
)

// udf that keeps only the rows whose id is not in the collected set
def containsUdf = udf((id: Int) => !srcIdList.value.contains(id))
vertices.filter(containsUdf(col("id"))).show(false)

which should give you:

+---+----+------+
|id |name|s_type|
+---+----+------+
|3  |rtf |C     |
+---+----+------+
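
As an aside, once the ids have been collected to the driver you can skip the UDF and push them into an isin filter instead (a sketch of the same idea; assumes the collected set is small enough to inline in the filter expression):

val ids = srcIdList.value.toSeq
vertices.filter(!col("id").isin(ids: _*)).show(false)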