我正在尝试理解Java Spark文档。有一个名为 Untyped User Defined Aggregate Functions 的部分,其中包含一些我无法理解的示例代码。这是代码:
package org.apache.spark.examples.sql;
// $example on:untyped_custom_aggregation$
import java.util.ArrayList;
import java.util.List;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
// $example off:untyped_custom_aggregation$
public class JavaUserDefinedUntypedAggregation {
// $example on:untyped_custom_aggregation$
public static class MyAverage extends UserDefinedAggregateFunction {
private StructType inputSchema;
private StructType bufferSchema;
public MyAverage() {
List<StructField> inputFields = new ArrayList<>();
inputFields.add(DataTypes.createStructField("inputColumn", DataTypes.LongType, true));
inputSchema = DataTypes.createStructType(inputFields);
List<StructField> bufferFields = new ArrayList<>();
bufferFields.add(DataTypes.createStructField("sum", DataTypes.LongType, true));
bufferFields.add(DataTypes.createStructField("count", DataTypes.LongType, true));
bufferSchema = DataTypes.createStructType(bufferFields);
}
// Data types of input arguments of this aggregate function
public StructType inputSchema() {
return inputSchema;
}
// Data types of values in the aggregation buffer
public StructType bufferSchema() {
return bufferSchema;
}
// The data type of the returned value
public DataType dataType() {
return DataTypes.DoubleType;
}
// Whether this function always returns the same output on the identical input
public boolean deterministic() {
return true;
}
// Initializes the given aggregation buffer. The buffer itself is a `Row` that in addition to
// standard methods like retrieving a value at an index (e.g., get(), getBoolean()), provides
// the opportunity to update its values. Note that arrays and maps inside the buffer are still
// immutable.
public void initialize(MutableAggregationBuffer buffer) {
buffer.update(0, 0L);
buffer.update(1, 0L);
}
// Updates the given aggregation buffer `buffer` with new input data from `input`
public void update(MutableAggregationBuffer buffer, Row input) {
if (!input.isNullAt(0)) {
long updatedSum = buffer.getLong(0) + input.getLong(0);
long updatedCount = buffer.getLong(1) + 1;
buffer.update(0, updatedSum);
buffer.update(1, updatedCount);
}
}
// Merges two aggregation buffers and stores the updated buffer values back to `buffer1`
public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
long mergedSum = buffer1.getLong(0) + buffer2.getLong(0);
long mergedCount = buffer1.getLong(1) + buffer2.getLong(1);
buffer1.update(0, mergedSum);
buffer1.update(1, mergedCount);
}
// Calculates the final result
public Double evaluate(Row buffer) {
return ((double) buffer.getLong(0)) / buffer.getLong(1);
}
}
// $example off:untyped_custom_aggregation$
public static void main(String[] args) {
SparkSession spark = SparkSession
.builder()
.appName("Java Spark SQL user-defined DataFrames aggregation example")
.getOrCreate();
// $example on:untyped_custom_aggregation$
// Register the function to access it
spark.udf().register("myAverage", new MyAverage());
Dataset<Row> df = spark.read().json("examples/src/main/resources/employees.json");
df.createOrReplaceTempView("employees");
df.show();
// +-------+------+
// | name|salary|
// +-------+------+
// |Michael| 3000|
// | Andy| 4500|
// | Justin| 3500|
// | Berta| 4000|
// +-------+------+
Dataset<Row> result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees");
result.show();
// +--------------+
// |average_salary|
// +--------------+
// | 3750.0|
// +--------------+
// $example off:untyped_custom_aggregation$
spark.stop();
}
}
我对上述代码的怀疑是:
initialize
,update
和merge
这些函数吗?inputSchema
和bufferSchema
的重要性是什么?我很惊讶它们存在,因为它们从未用于创建任何DataFrame。它们应该存在于每个UDF中吗?如果是,那么他们应该是完全相同的名字吗?inputSchema
和bufferSchema
的吸气者未被命名为getInputSchema()
和getBufferSchema()
?为什么没有这些变量的设定者?deterministic()
的函数有什么意义?请给出一个调用此函数有用的方案。一般来说,我想知道如何在Spark中编写用户定义的聚合函数。
答案 0 :(得分:5)
每当我想创建一个UDF时,我应该有函数初始化,更新和合并
UDF 代表用户定义的函数,而方法initialize
,update
和merge
代表用户定义的聚合函数(又名 UDAF )。
UDF是一个与单行一起工作的函数(通常)产生一行(例如upper
函数)。
UDAF是一个可以使用零行或多行来生成一行的函数(例如count
聚合函数)。
对于用户定义的函数(UDF),您当然不必(并且将无法)具有函数initialize
,update
和merge
。
使用udf
functions中的任何一个来定义和注册UDF。
val myUpper = udf { (s: String) => s.toUpperCase }
如何在Spark中编写用户定义的聚合函数。
变量
inputSchema
和bufferSchema
的重要性是什么?
(无耻插件:我一直在UserDefinedAggregateFunction — Contract for User-Defined Untyped Aggregate Functions (UDAFs)中掌握Sparking SQL书中的UDAF)
引用Untyped User-Defined Aggregate Functions:
// Data types of input arguments of this aggregate function def inputSchema: StructType = StructType(StructField("inputColumn", LongType) :: Nil) // Data types of values in the aggregation buffer def bufferSchema: StructType = { StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil) }
换句话说,inputSchema
是您对输入的期望,而bufferSchema
是您在进行聚合时暂时保留的内容。
为什么没有这些变量的设定者?
它们是由Spark管理的扩展点。
这里名为
deterministic()
的函数有什么意义?
引用Untyped User-Defined Aggregate Functions:
// Whether this function always returns the same output on the identical input def deterministic: Boolean = true
请说明调用此函数有用的方案。
这是我仍在努力的事情,因此今天无法回答。