我试图为scala中的每个组实现80%的修剪平均值以摆脱异常值。但只有当该组中的记录数量至少超过10时才必须应用。
示例,
val sales = Seq(
("Warsaw", 2016, 100),
("Warsaw", 2017, 200),
("Boston", 2015, 50),
("Boston", 2016, 150),
("Toronto", 2017, 50)
).toDF("city", "year", "amount")
所以在这个数据集中,如果我正在做这个组,
val groupByCityAndYear = sales
.groupBy("city", "year").count()
.agg(avg($"amount").as("avg_amount"))
所以在这种情况下,如果计数超过10,那么应该删除异常值(可以修剪80%的意思),否则直接平均($“金额”)。我怎样才能实现这一目标?
以下是我得到的修剪均值的正确解释,以解释这种情况,
考虑修剪的意思是什么:在原型情况下,您首先按递增顺序对数据进行排序。然后从底部算起修剪百分比并丢弃这些值。例如,10%的修剪平均值是常见的;在这种情况下,您从最低值开始计算,直到您通过集合中所有数据的10%。低于该标记的值被搁置。同样地,您从最高值开始倒计时,直到您通过修剪百分比,并将所有值设置为大于此值。你现在剩下80%的中间人。你取其平均值,这就是你的10%修剪平均值
答案 0 :(得分:1)
这可以通过窗口功能完成,但价格昂贵:
import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window
val w = Window.partitionBy("city", "year").orderBy("amount")
sales
.withColumn("rn", row_number().over(w))
.withColumn("count", count("*").over(w))
.groupBy("city", "year")
.agg(avg(when(
($"count" < 10) or ($"rn" between($"count" * 0.1, $"count" * 0.9)),
$"amount"
)) as "avg_amount")