我想计算从分组的Spark数据帧的列中有多少记录是真的,但我不知道如何在python中做到这一点。例如,我的数据包含region
,salary
和IsUnemployed
列,其中IsUnemployed
为布尔值。我想看看每个地区有多少失业人员。我知道我们可以执行filter
然后groupby
,但我想在下面同时生成两个聚合
from pyspark.sql import functions as F
data.groupby("Region").agg(F.avg("Salary"), F.count("IsUnemployed"))
答案 0 :(得分:18)
最简单的解决方案可能是普通CAST
(C样式,其中TRUE
- > 1,FALSE
- > 0)SUM
:
(data
.groupby("Region")
.agg(F.avg("Salary"), F.sum(F.col("IsUnemployed").cast("long"))))
CASE WHEN
COUNT
(data
.groupby("Region")
.agg(
F.avg("Salary"),
F.count(F.when(F.col("IsUnemployed"), F.col("IsUnemployed")))))
更具普遍性和惯用解决方案:
path to me
但这显然是一种矫枉过正。