如何编写用户定义的聚合函数?

时间:2017-07-05 21:37:17

标签: java apache-spark apache-spark-sql

我正在尝试理解Java Spark文档。有一个名为 Untyped User Defined Aggregate Functions 的部分,其中包含一些我无法理解的示例代码。这是代码:

package org.apache.spark.examples.sql;

// $example on:untyped_custom_aggregation$
import java.util.ArrayList;
import java.util.List;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
// $example off:untyped_custom_aggregation$

public class JavaUserDefinedUntypedAggregation {

  // $example on:untyped_custom_aggregation$
  public static class MyAverage extends UserDefinedAggregateFunction {

    private StructType inputSchema;
    private StructType bufferSchema;

    public MyAverage() {
      List<StructField> inputFields = new ArrayList<>();
      inputFields.add(DataTypes.createStructField("inputColumn", DataTypes.LongType, true));
      inputSchema = DataTypes.createStructType(inputFields);

      List<StructField> bufferFields = new ArrayList<>();
      bufferFields.add(DataTypes.createStructField("sum", DataTypes.LongType, true));
      bufferFields.add(DataTypes.createStructField("count", DataTypes.LongType, true));
      bufferSchema = DataTypes.createStructType(bufferFields);
    }
    // Data types of input arguments of this aggregate function
    public StructType inputSchema() {
      return inputSchema;
    }
    // Data types of values in the aggregation buffer
    public StructType bufferSchema() {
      return bufferSchema;
    }
    // The data type of the returned value
    public DataType dataType() {
      return DataTypes.DoubleType;
    }
    // Whether this function always returns the same output on the identical input
    public boolean deterministic() {
      return true;
    }
    // Initializes the given aggregation buffer. The buffer itself is a `Row` that in addition to
    // standard methods like retrieving a value at an index (e.g., get(), getBoolean()), provides
    // the opportunity to update its values. Note that arrays and maps inside the buffer are still
    // immutable.
    public void initialize(MutableAggregationBuffer buffer) {
      buffer.update(0, 0L);
      buffer.update(1, 0L);
    }
    // Updates the given aggregation buffer `buffer` with new input data from `input`
    public void update(MutableAggregationBuffer buffer, Row input) {
      if (!input.isNullAt(0)) {
        long updatedSum = buffer.getLong(0) + input.getLong(0);
        long updatedCount = buffer.getLong(1) + 1;
        buffer.update(0, updatedSum);
        buffer.update(1, updatedCount);
      }
    }
    // Merges two aggregation buffers and stores the updated buffer values back to `buffer1`
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
      long mergedSum = buffer1.getLong(0) + buffer2.getLong(0);
      long mergedCount = buffer1.getLong(1) + buffer2.getLong(1);
      buffer1.update(0, mergedSum);
      buffer1.update(1, mergedCount);
    }
    // Calculates the final result
    public Double evaluate(Row buffer) {
      return ((double) buffer.getLong(0)) / buffer.getLong(1);
    }
  }
  // $example off:untyped_custom_aggregation$

  public static void main(String[] args) {
    SparkSession spark = SparkSession
      .builder()
      .appName("Java Spark SQL user-defined DataFrames aggregation example")
      .getOrCreate();

    // $example on:untyped_custom_aggregation$
    // Register the function to access it
    spark.udf().register("myAverage", new MyAverage());

    Dataset<Row> df = spark.read().json("examples/src/main/resources/employees.json");
    df.createOrReplaceTempView("employees");
    df.show();
    // +-------+------+
    // |   name|salary|
    // +-------+------+
    // |Michael|  3000|
    // |   Andy|  4500|
    // | Justin|  3500|
    // |  Berta|  4000|
    // +-------+------+

    Dataset<Row> result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees");
    result.show();
    // +--------------+
    // |average_salary|
    // +--------------+
    // |        3750.0|
    // +--------------+
    // $example off:untyped_custom_aggregation$

    spark.stop();
  }
}

我对上述代码的怀疑是:

  • 每当我想创建UDF时,我应该有initializeupdatemerge这些函数吗?
  • 变量inputSchemabufferSchema的重要性是什么?我很惊讶它们存在,因为它们从未用于创建任何DataFrame。它们应该存在于每个UDF中吗?如果是,那么他们应该是完全相同的名字吗?
  • 为什么inputSchemabufferSchema的吸气者未被命名为getInputSchema()getBufferSchema()?为什么没有这些变量的设定者?
  • 这里名为deterministic()的函数有什么意义?请给出一个调用此函数有用的方案。

一般来说,我想知道如何在Spark中编写用户定义的聚合函数。

1 个答案:

答案 0 :(得分:5)

  

每当我想创建一个UDF时,我应该有函数初始化,更新和合并

UDF 代表用户定义的函数,而方法initializeupdatemerge代表用户定义的聚合函数(又名 UDAF )。

UDF是一个与单行一起工作的函数(通常)产生一行(例如upper函数)。

UDAF是一个可以使用零行或多行来生成一行的函数(例如count聚合函数)。

对于用户定义的函数(UDF),您当然不必(并且将无法)具有函数initializeupdatemerge

使用udf functions中的任何一个来定义和注册UDF。

val myUpper = udf { (s: String) => s.toUpperCase }
  

如何在Spark中编写用户定义的聚合函数。

     

变量inputSchemabufferSchema的重要性是什么?

无耻插件:我一直在UserDefinedAggregateFunction — Contract for User-Defined Untyped Aggregate Functions (UDAFs)中掌握Sparking SQL书中的UDAF)

引用Untyped User-Defined Aggregate Functions

// Data types of input arguments of this aggregate function
def inputSchema: StructType = StructType(StructField("inputColumn", LongType) :: Nil)

// Data types of values in the aggregation buffer
def bufferSchema: StructType = {
  StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)
}

换句话说,inputSchema是您对输入的期望,而bufferSchema是您在进行聚合时暂时保留的内容。

  

为什么没有这些变量的设定者?

它们是由Spark管理的扩展点。

  

这里名为deterministic()的函数有什么意义?

引用Untyped User-Defined Aggregate Functions

// Whether this function always returns the same output on the identical input
def deterministic: Boolean = true
     

请说明调用此函数有用的方案。

这是我仍在努力的事情,因此今天无法回答。