过滤掉前n个值后,计算avg和stddev

时间:2019-07-16 17:04:29

标签: apache-spark pyspark pyspark-sql

我有一个类似的Spark DataFrame

col_a |  col_b |  metric
------------------------
a1    |  b1    |  100
a1    |  b2    |  1
a1    |  b3    |  3
a1    |  b4    |  20
a2    |  b5    |  4
a2    |  b6    |  80
a2    |  b7    |  20
a2    |  b8    |  10
a2    |  b9    |  20
a2    |  b10   |  5

现在,在过滤掉顶部col_a值之后,我想计算列n上聚合的平均值和标准偏差。

例如,如果n=1应根据以下过滤表计算平均值和标准偏差:

col_a |  col_b |  metric
------------------------
a1    |  b2    |  1
a1    |  b3    |  3
a1    |  b4    |  20
a2    |  b5    |  4
a2    |  b7    |  20
a2    |  b8    |  10
a2    |  b9    |  20
a2    |  b10   |  5

会导致

col_a      | avg   | std
-----------------------------
a1         |  8.0  |  8.5
a2         |  11.8 |  6.9

没有这种过滤,我将运行

df.groupby('col_a').agg(f.avg('metric'), f.stddev('metric'))

您知道如何添加这样的过滤器吗?

1 个答案:

答案 0 :(得分:1)

 

您可以添加一个中间列以用于过滤。根据您想要处理领带的方式,可以使用documentationpyspark.sql.functions.dense_rank

以下示例演示了区别:

from pyspark.sql import Window
import pyspark.sql.functions as f

w = Window.partitionBy("col_a").orderBy(f.desc("metric"))
df = df.select(
    "*",
    f.dense_rank().over(w).alias("metric_rank"),
    f.row_number().over(w).alias("metric_row")
)

df.show()
#+-----+-----+------+-----------+----------+
#|col_a|col_b|metric|metric_rank|metric_row|
#+-----+-----+------+-----------+----------+
#|   a2|   b6|    80|          1|         1|
#|   a2|   b7|    20|          2|         2|
#|   a2|   b9|    20|          2|         3|
#|   a2|   b8|    10|          3|         4|
#|   a2|  b10|     5|          4|         5|
#|   a2|   b5|     4|          5|         6|
#|   a1|   b1|   100|          1|         1|
#|   a1|   b4|    20|          2|         2|
#|   a1|   b3|     3|          3|         3|
#|   a1|   b2|     1|          4|         4|
#+-----+-----+------+-----------+----------+

现在,仅根据metric_rankmetric_row进行过滤并进行汇总。在您的特定示例(其中n=1)中没有区别:

n = 1
df.where(f.col("metric_rank") > n)\
    .groupby('col_a')\
    .agg(f.avg('metric'), f.stddev_pop('metric'))\
    .show()
#+-----+-----------+------------------+
#|col_a|avg(metric)|stddev_pop(metric)|
#+-----+-----------+------------------+
#|   a2|       11.8| 6.997142273814361|
#|   a1|        8.0| 8.524474568362947|
#+-----+-----------+------------------+

(注意:您使用pyspark.sql.functions.row_number会返回无偏样本标准偏差,而您显示的数字实际上是总体标准偏差,即stddev

但是,您可以看到,如果使用n=2,则根据用于过滤器的两列中的哪一列会有不同的结果。查看metric的两行中20的{​​{1}}。如果要精确排除2行,则应使用a2。如果要删除值在前2位的所有行,则需要使用metric_row