我有一个类似的Spark DataFrame
col_a | col_b | metric
------------------------
a1 | b1 | 100
a1 | b2 | 1
a1 | b3 | 3
a1 | b4 | 20
a2 | b5 | 4
a2 | b6 | 80
a2 | b7 | 20
a2 | b8 | 10
a2 | b9 | 20
a2 | b10 | 5
现在,在过滤掉顶部col_a
值之后,我想计算列n
上聚合的平均值和标准偏差。
例如,如果n=1
应根据以下过滤表计算平均值和标准偏差:
col_a | col_b | metric
------------------------
a1 | b2 | 1
a1 | b3 | 3
a1 | b4 | 20
a2 | b5 | 4
a2 | b7 | 20
a2 | b8 | 10
a2 | b9 | 20
a2 | b10 | 5
会导致
col_a | avg | std
-----------------------------
a1 | 8.0 | 8.5
a2 | 11.8 | 6.9
没有这种过滤,我将运行
df.groupby('col_a').agg(f.avg('metric'), f.stddev('metric'))
您知道如何添加这样的过滤器吗?
答案 0 :(得分:1)
您可以添加一个中间列以用于过滤。根据您想要处理领带的方式,可以使用documentation或pyspark.sql.functions.dense_rank
。
以下示例演示了区别:
from pyspark.sql import Window
import pyspark.sql.functions as f
w = Window.partitionBy("col_a").orderBy(f.desc("metric"))
df = df.select(
"*",
f.dense_rank().over(w).alias("metric_rank"),
f.row_number().over(w).alias("metric_row")
)
df.show()
#+-----+-----+------+-----------+----------+
#|col_a|col_b|metric|metric_rank|metric_row|
#+-----+-----+------+-----------+----------+
#| a2| b6| 80| 1| 1|
#| a2| b7| 20| 2| 2|
#| a2| b9| 20| 2| 3|
#| a2| b8| 10| 3| 4|
#| a2| b10| 5| 4| 5|
#| a2| b5| 4| 5| 6|
#| a1| b1| 100| 1| 1|
#| a1| b4| 20| 2| 2|
#| a1| b3| 3| 3| 3|
#| a1| b2| 1| 4| 4|
#+-----+-----+------+-----------+----------+
现在,仅根据metric_rank
或metric_row
进行过滤并进行汇总。在您的特定示例(其中n=1
)中没有区别:
n = 1
df.where(f.col("metric_rank") > n)\
.groupby('col_a')\
.agg(f.avg('metric'), f.stddev_pop('metric'))\
.show()
#+-----+-----------+------------------+
#|col_a|avg(metric)|stddev_pop(metric)|
#+-----+-----------+------------------+
#| a2| 11.8| 6.997142273814361|
#| a1| 8.0| 8.524474568362947|
#+-----+-----------+------------------+
(注意:您使用pyspark.sql.functions.row_number
会返回无偏样本标准偏差,而您显示的数字实际上是总体标准偏差,即stddev
)
但是,您可以看到,如果使用n=2
,则根据用于过滤器的两列中的哪一列会有不同的结果。查看metric
的两行中20
的{{1}}。如果要精确排除2行,则应使用a2
。如果要删除值在前2位的所有行,则需要使用metric_row
。