Aggregate for mode (most common element) in a Spark DataFrame

Asked: 2016-10-25 01:26:06

Tags: apache-spark aggregate average

In Spark I am using a library to which I am supposed to supply the aggregates; the library then performs a series of joins/groupBys and calls the aggregate at the end. I am trying to avoid violating encapsulation (though I can if necessary) and would like to call this method with just an aggregate (a traditional sum or min, etc.).

In this case I am trying to compute the mode of a column, but I am not sure how to express that as an aggregate. A minimal sketch of the kind of call that works today is shown below.
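For context, this is roughly what passing a built-in aggregate looks like (an illustrative sketch; the session setup, data, and column names are placeholders, not from the original question):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

val spark = SparkSession.builder().master("local[*]").appName("agg-example").getOrCreate()
import spark.implicits._

// Hypothetical data standing in for the question's DataFrame.
val df = Seq(("a", 1L), ("a", 3L), ("b", 2L)).toDF("group_id", "amount")

// Built-in aggregates like sum or min plug straight into agg();
// the question is how to express mode as such a Column expression.
df.groupBy("group_id").agg(sum("amount"), min("amount")).show()
```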

1 Answer:

Answer 0 (score: 2)

Here is a Spark (2.1.0) UDAF that computes the statistical mode of a given column:

package org.anish.spark.mostcommonvalue

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

import scalaz.Scalaz._

/**
  * Spark User Defined Aggregate Function to calculate the most frequent value in a column, i.e. the
  * statistical mode. When two values are tied for the highest frequency, this function arbitrarily
  * returns one of them, whereas a strict statistical mode would report both.
  *
  * Usage:
  *
  * DataFrame / DataSet DSL
  * val mostCommonValue = new MostCommonValue
  * df.groupBy("group_id").agg(mostCommonValue(col("mode_column")), mostCommonValue(col("city")))
  *
  * Spark SQL:
  * sqlContext.udf.register("mode", new MostCommonValue)
  * %sql
  * -- Use a group_by statement and call the UDAF.
  * select group_id, mode(id) from table group by group_id
  * 
  * Reference: https://docs.databricks.com/spark/latest/spark-sql/udaf-scala.html
  *
  * Created by anish on 26/05/17.
  */
class MostCommonValue extends UserDefinedAggregateFunction {

  // The input schema of the aggregate function.
  // StringType is used because mode is also meaningful for nominal (categorical) data.
  override def inputSchema: StructType =
    StructType(StructField("value", StringType) :: Nil)

  // The internal buffer used while computing the aggregate.
  // The frequency of every distinct element encountered for the attribute is kept in this map.
  override def bufferSchema: StructType = StructType(
    StructField("frequencyMap", DataTypes.createMapType(StringType, LongType)) :: Nil
  )

  // This is the output type of your aggregation function.
  override def dataType: DataType = StringType

  override def deterministic: Boolean = true

  // The initial value of the aggregation buffer: an empty frequency map.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Map[String, Long]()
  }

  // Update the buffer with one input row: increment the count of the observed value.
  // (|+| is scalaz's semigroup append on Map, which sums the counts of matching keys.)
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getAs[Map[String, Long]](0) |+| Map(input.getAs[String](0) -> 1L)
  }

  // Merge two partial frequency maps, e.g. buffers computed on different partitions.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Map[String, Long]](0) |+| buffer2.getAs[Map[String, Long]](0)
  }

  // Output the final value: the key with the highest frequency in the buffer.
  override def evaluate(buffer: Row): String = {
    buffer.getAs[Map[String, Long]](0).maxBy(_._2)._1
  }
}
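
For reference, a minimal end-to-end usage sketch following the scaladoc above (the sample data and session setup are illustrative, not part of the original answer):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col
import org.anish.spark.mostcommonvalue.MostCommonValue

val spark = SparkSession.builder().master("local[*]").appName("mode-udaf").getOrCreate()
import spark.implicits._

// Hypothetical sample data: two groups with a clear most-common value each.
val df = Seq(
  (1, "red"), (1, "red"), (1, "blue"),
  (2, "green"), (2, "green"), (2, "red")
).toDF("group_id", "color")

// DataFrame DSL usage.
val mostCommonValue = new MostCommonValue
df.groupBy("group_id").agg(mostCommonValue(col("color")).as("mode_color")).show()
// Expected: group 1 -> red, group 2 -> green

// Spark SQL usage.
spark.udf.register("mode", new MostCommonValue)
df.createOrReplaceTempView("t")
spark.sql("select group_id, mode(color) from t group by group_id").show()
```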

Credit / source: https://gist.github.com/anish749/6a815ed281f538068a0d3a20ca9044fa
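
A side note (my addition, not part of the linked gist): the `|+|` operator comes from scalaz's Semigroup instance for Map, which sums the counts of matching keys. If you would rather avoid the scalaz dependency, the same merge can be written in plain Scala:

```scala
// Plain-Scala equivalent of scalaz's |+| on Map[String, Long]:
// sum the counts of keys present in both maps, keep the rest unchanged.
def mergeCounts(a: Map[String, Long], b: Map[String, Long]): Map[String, Long] =
  b.foldLeft(a) { case (acc, (k, v)) => acc.updated(k, acc.getOrElse(k, 0L) + v) }
```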