I have two dataframes:
edges =
srcId  dstId  timestamp
1      4      1346564657
1      2      1345769687
2      4      1345769687
4      1      1345769687

vertices =
id  name  s_type
1   abc   A
2   def   B
3   rtf   C
4   wrr   D
I would like to get the subset of vertices whose ids are not mentioned in the srcId or dstId columns of edges.
This is the expected output:
sub_vertices =
id  name  s_type
3   rtf   C
How can I do this?
val sub_vertices = vertices
.join(edges, col("id") =!= col("srcId") && col("id") =!= col("dstId"), "left")
.na.fill(0)
.drop("srcId","dstId", "timestamp")
.dropDuplicates()
This is my current code, but the result is incorrect.
Answer 0 (score: 2)
You're almost there; a few things need to change: join on equality so that each vertex pairs with any edge that references it, then keep only the rows where the join found no match:
import org.apache.spark.sql.functions.col
import spark.implicits._ // for the $"..." column syntax (already in scope in spark-shell)

val sub_vertices = vertices
  .join(edges, col("id") === col("srcId") || col("id") === col("dstId"), "left")
  .filter($"srcId".isNull && $"dstId".isNull) // vertices with no matching edge
  .drop("srcId", "dstId", "timestamp")
Output:
+---+----+------+
|id |name|s_type|
+---+----+------+
|3 |rtf |C |
+---+----+------+
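As a side note, the same logic can be expressed with Spark's built-in left_anti join type, which keeps only the left-side rows that have no match on the right, so the null filter and column drops become unnecessary. A minimal sketch, assuming the same edges and vertices dataframes:

import org.apache.spark.sql.functions.col

// left_anti keeps only vertices with no matching edge and never
// pulls in the edge columns, so there is nothing to drop afterwards
val sub_vertices = vertices
  .join(edges, col("id") === col("srcId") || col("id") === col("dstId"), "left_anti")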
Answer 1 (score: 2)
You can collect the srcIds and dstIds into a set, broadcast it, and use the broadcast set to filter the vertices dataframe:
import org.apache.spark.sql.functions._

// collect all the srcId and dstId values from the edges dataframe into a set and broadcast it
val srcIdList = sc.broadcast(
  edges.select(collect_set("srcId").as("collectSrc"), collect_set("dstId").as("collectDst"))
    .rdd.map(row => (row.getAs[Seq[Int]](0) ++ row.getAs[Seq[Int]](1)).toSet)
    .collect()(0))

// using a udf, remove all the rows whose id appears in the set collected above
def containsUdf = udf((id: Int) => !srcIdList.value.contains(id))

vertices.filter(containsUdf(col("id"))).show(false)
which should give you:
+---+----+------+
|id |name|s_type|
+---+----+------+
|3 |rtf |C |
+---+----+------+
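If you would rather avoid the UDF and the explicit broadcast, here is a sketch of the same filter using isin over the collected ids (assuming, as above, that the id columns are integers and the set of referenced ids fits comfortably on the driver):

import org.apache.spark.sql.functions.col

// union is positional, so this stacks the srcId and dstId columns into one
val usedIds = edges.select("srcId").union(edges.select("dstId"))
  .distinct()
  .collect()
  .map(_.getInt(0))

// keep only the vertices whose id is not referenced by any edge
vertices.filter(!col("id").isin(usedIds: _*)).show(false)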