我在DataFrame的一列中有一个值列表,我想用它来过滤另一个更大的DataFrame,该数据框有2列要根据其匹配。
这里是一个例子。
df1 = sqlContext.createDataFrame(
[(1, "a"), (2, "b"), (3, "c"), (4, "d"), (5, "e")],
("ID", "label1"))
df2 = sqlContext.createDataFrame(
[
(1, 2, "x"),
(2, 1, "y"),
(3, 1, "z"),
(4, 6, "s"),
(7, 2, "t"),
(8, 9, "z")
],
("ID1", "ID2", "label2")
)
我最后想要得到的是一个数据帧,其中包含df2
和ID1
都位于ID2
中的df1
条目。在这个例子中,看起来像这样;
+---+---+------+
|ID1|ID2| label|
+---+---+------+
| 1| 2| x|
| 2| 1| y|
| 3| 1| z|
+---+---+------+
我试图通过如下所示的联接来做到这一点;
df = df1.join(df2, (df1.ID == df2.ID1) | (df1.ID == df2.ID2))
但这会炸开我的桌子并给我
+---+------+---+---+------+
| ID|label1|ID1|ID2|label2|
+---+------+---+---+------+
| 1| a| 1| 2| x|
| 1| a| 2| 1| y|
| 1| a| 3| 1| z|
| 2| b| 1| 2| x|
| 2| b| 2| 1| y|
| 2| b| 7| 2| t|
| 3| c| 3| 1| z|
| 4| d| 4| 6| s|
+---+------+---+---+------+
然后
df = df1.join(df2, (df1.ID == df2.ID1) & (df1.ID == df2.ID2))
显然不是我想要的........对人们有帮助吗?
答案 0 :(得分:3)
我认为您可以使用您的初始join语句并进一步对DataFrame进行分组,并选择出现两次的行,因为ID1
中应包含ID2
和df1
。因此,它们应该在结果中出现两次,因为联接应该将df2
的行与df1
中的两个ID值重复。
结果语句如下:
from pyspark.sql.functions import col
df2.join(
df1,
[(df1.ID==df2.ID1)|(df1.ID==df2.ID2)],
how="left"
).groupBy("ID1","ID2","label").count().filter(col("count")==2).show()
结果是:
+---+---+-----+-----+
|ID1|ID2|label|count|
+---+---+-----+-----+
| 2 | 1 | y | 2 |
| 3 | 1 | z | 2 |
| 1 | 2 | x | 2 |
+---+---+-----+-----+
如果您不喜欢count列,可以在语句后附加select("ID1","ID2","label")
答案 1 :(得分:1)
这是使用spark-sql的另一种方法:
首先将您的DataFrames注册为表:
df1.createOrReplaceTempView('df1')
df2.createOrReplaceTempView('df2')
下一步运行以下查询:
df = sqlContext.sql(
"SELECT * FROM df2 WHERE ID1 IN (SELECT ID FROM df1) AND ID2 IN (SELECT ID FROM df1)"
)
df.show()
#+---+---+------+
#|ID1|ID2|label2|
#+---+---+------+
#| 3| 1| z|
#| 2| 1| y|
#| 1| 2| x|
#+---+---+------+
答案 2 :(得分:0)
可以在过滤数据后单独使用相交。这是使用核心Spark API的解决方案
>>> df1.show()
+---+------+
| ID|label1|
+---+------+
| 1| a|
| 2| b|
| 3| c|
| 4| d|
| 5| e|
+---+------+
>>> df2.show()
+---+---+------+
|ID1|ID2|label2|
+---+---+------+
| 1| 2| x|
| 2| 1| y|
| 3| 1| z|
| 4| 6| s|
| 7| 2| t|
| 8| 9| z|
+---+---+------+
>>> df3 = df1.join(df2, (df1.ID == df2.ID1)).select(df2['*'])
>>> df4 = df1.join(df2, (df1.ID == df2.ID2)).select(df2['*'])
>>> df3.intersect(df4).show()
+---+---+------+
|ID1|ID2|label2|
+---+---+------+
| 2| 1| y|
| 3| 1| z|
| 1| 2| x|
+---+---+------+