以下是我的dataframe
:
df = spark.createDataFrame([
(0, 1),
(0, 2),
(0, 5),
(1, 1),
(1, 2),
(1, 3),
(1, 5),
(2, 1),
(2, 2)
], ["id", "product"])
我需要进行groupBy
的{{1}}并收集所有物品,如下所示,但是我需要检查产品数量,如果数量少于2,则该数量不应该存在收集的物品。
例如,乘积3仅重复一次,即3的计数为1,小于2,因此在随后的数据框中不可用。看来我需要做两个id
:
预期输出:
groupBy
答案 0 :(得分:2)
我认为确实有两个groupBy
是一个不错的解决方案,您可以在第一个leftsemi
之后使用groupBy
联接来过滤初始DataFrame
。工作示例解决方案:
import pyspark.sql.functions as F
df = spark.createDataFrame([
(0, 1),
(0, 2),
(0, 5),
(1, 1),
(1, 2),
(1, 3),
(1, 5),
(2, 1),
(2, 2)
], ["id", "product"])
df = df\
.join(df.groupBy('product').count().filter(F.col('count')>=2),'product','leftsemi').distinct()\
.orderBy(F.col('id'),F.col('product'))\
.groupBy('id').agg(F.collect_list('product').alias('product'))
df.show()
orderBy
子句是可选的,除非您关心结果的顺序。输出:
+---+---------+
| id| product|
+---+---------+
| 0|[1, 2, 5]|
| 1|[1, 2, 5]|
| 2| [1, 2]|
+---+---------+
希望这会有所帮助!
答案 1 :(得分:0)
一种方法是使用Window
获取每个产品的计数,并将其用于filter
之前的groupBy()
:
import pyspark.sql.functions as f
from pyspark.sql import Window
df.withColumn('count', f.count('*').over(Window.partitionBy('product')))\
.where('count > 1')\
.groupBy('id')\
.agg(f.collect_list('product').alias('items'))\
.show()
#+---+---------+
#| id| items|
#+---+---------+
#| 0|[5, 1, 2]|
#| 1|[5, 1, 2]|
#| 2| [1, 2]|
#+---+---------+