在scala中应用条件修剪均值

时间:2018-05-12 06:55:18

标签: scala apache-spark apache-spark-sql spark-dataframe amazon-emr

我试图为scala中的每个组实现80%的修剪平均值以摆脱异常值。但只有当该组中的记录数量至少超过10时才必须应用。

示例,

val sales = Seq(
  ("Warsaw", 2016, 100),
  ("Warsaw", 2017, 200),
  ("Boston", 2015, 50),
  ("Boston", 2016, 150),
  ("Toronto", 2017, 50)
).toDF("city", "year", "amount")

所以在这个数据集中,如果我正在做这个组,

val groupByCityAndYear = sales
  .groupBy("city", "year").count() 
  .agg(avg($"amount").as("avg_amount"))

所以在这种情况下,如果计数超过10,那么应该删除异常值(可以修剪80%的意思),否则直接平均($“金额”)。我怎样才能实现这一目标?

以下是我得到的修剪均值的正确解释,以解释这种情况,

考虑修剪的意思是什么:在原型情况下,您首先按递增顺序对数据进行排序。然后从底部算起修剪百分比并丢弃这些值。例如,10%的修剪平均值是常见的;在这种情况下,您从最低值开始计算,直到您通过集合中所有数据的10%。低于该标记的值被搁置。同样地,您从最高值开始倒计时,直到您通过修剪百分比,并将所有值设置为大于此值。你现在剩下80%的中间人。你取其平均值,这就是你的10%修剪平均值

1 个答案:

答案 0 :(得分:1)

这可以通过窗口功能完成,但价格昂贵:

import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window

val w = Window.partitionBy("city", "year").orderBy("amount")

sales
  .withColumn("rn", row_number().over(w))
  .withColumn("count", count("*").over(w))
  .groupBy("city", "year")
  .agg(avg(when(
    ($"count" < 10) or ($"rn" between($"count" * 0.1, $"count" * 0.9)), 
    $"amount"
  )) as "avg_amount")