我有一个Spark sql数据帧,包含ID
列和n
"数据"列,即
id | dat1 | dat2 | ... | datn
id
列是唯一确定的,而查看dat1 ... datn
则可能存在重复。
我的目标是找到这些重复项的id
。
到目前为止我的方法:
使用groupBy
获取重复的行:
dup_df = df.groupBy(df.columns[1:]).count().filter('count > 1')
将dup_df
加入整个df
以获取重复的行,包括 id
:
df.join(dup_df, df.columns[1:])
我确信这基本上是正确的,但由于dat1 ... datn
列包含null
值,因此失败了。
要对join
值进行null
,我找到.e.g this SO post。但这需要构建一个巨大的字符串连接条件"。
因此我的问题:
joins
值进行null
?id
?BTW:我使用的是Spark 2.1.0和Python 3.5.3
答案 0 :(得分:12)
如果每个组的数量ids
相对较小,您可以groupBy
和collect_list
。必需的进口
from pyspark.sql.functions import collect_list, size
示例数据:
df = sc.parallelize([
(1, "a", "b", 3),
(2, None, "f", None),
(3, "g", "h", 4),
(4, None, "f", None),
(5, "a", "b", 3)
]).toDF(["id"])
查询:
(df
.groupBy(df.columns[1:])
.agg(collect_list("id").alias("ids"))
.where(size("ids") > 1))
结果:
+----+---+----+------+
| _2| _3| _4| ids|
+----+---+----+------+
|null| f|null|[2, 4]|
| a| b| 3|[1, 5]|
+----+---+----+------+
您可以将explode
两次(或使用udf
)应用于与join
返回的输出等效的输出。
您还可以使用每组最小id
来识别群组。一些额外的进口:
from pyspark.sql.window import Window
from pyspark.sql.functions import col, count, min
窗口定义:
w = Window.partitionBy(df.columns[1:])
查询:
(df
.select(
"*",
count("*").over(w).alias("_cnt"),
min("id").over(w).alias("group"))
.where(col("_cnt") > 1))
结果:
+---+----+---+----+----+-----+
| id| _2| _3| _4|_cnt|group|
+---+----+---+----+----+-----+
| 2|null| f|null| 2| 2|
| 4|null| f|null| 2| 2|
| 1| a| b| 3| 2| 1|
| 5| a| b| 3| 2| 1|
+---+----+---+----+----+-----+
您可以进一步使用group
列进行自我加入。