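I'm trying to create a user-defined aggregate function (UDAF) in Java with Apache Spark SQL that returns multiple arrays when it finishes. I've searched online and can't find any examples or advice on how to do this.
I can return a single array, but I can't work out how to get the data into the right format in the evaluate() method so that it returns multiple arrays.
The UDAF itself works, since I can print out the arrays inside evaluate(); I just can't figure out how to return those arrays to the calling code (shown below for reference).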
UserDefinedAggregateFunction customUDAF = new CustomUDAF();
DataFrame resultingDataFrame = dataFrame.groupBy().agg(customUDAF.apply(dataFrame.col("long_col"), dataFrame.col("double_col"))).as("processed_data");
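I've included the entire custom UDAF class below, but the key methods are dataType() and evaluate(), which are shown first.
Any help or advice would be greatly appreciated. Thanks.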
import java.util.ArrayList;
import java.util.List;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class CustomUDAF extends UserDefinedAggregateFunction {
    @Override
    public DataType dataType() {
        // TODO: Is this the correct way to return 2 arrays?
        return new StructType()
                .add("longArray", DataTypes.createArrayType(DataTypes.LongType, false))
                .add("dataArray", DataTypes.createArrayType(DataTypes.DoubleType, false));
    }

    @Override
    public Object evaluate(Row buffer) {
        // Data conversion
        List<Long> longList = new ArrayList<Long>(buffer.getList(0));
        List<Double> dataList = new ArrayList<Double>(buffer.getList(1));
        // Processing of data (omitted)
        // TODO: How to get data into format needed to return 2 arrays?
        return dataList;
    }

    @Override
    public StructType inputSchema() {
        return new StructType().add("long", DataTypes.LongType).add("data", DataTypes.DoubleType);
    }

    @Override
    public StructType bufferSchema() {
        return new StructType()
                .add("longArray", DataTypes.createArrayType(DataTypes.LongType, false))
                .add("dataArray", DataTypes.createArrayType(DataTypes.DoubleType, false));
    }

    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        // Both buffer slots start out as empty lists
        buffer.update(0, new ArrayList<Long>());
        buffer.update(1, new ArrayList<Double>());
    }

    @Override
    public void update(MutableAggregationBuffer buffer, Row row) {
        // Append the incoming (long, double) pair to the two buffer lists
        ArrayList<Long> longList = new ArrayList<Long>(buffer.getList(0));
        longList.add(row.getLong(0));
        ArrayList<Double> dataList = new ArrayList<Double>(buffer.getList(1));
        dataList.add(row.getDouble(1));
        buffer.update(0, longList);
        buffer.update(1, dataList);
    }

    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        // Concatenate the partial lists from buffer2 onto buffer1
        ArrayList<Long> longList = new ArrayList<Long>(buffer1.getList(0));
        longList.addAll(buffer2.getList(0));
        ArrayList<Double> dataList = new ArrayList<Double>(buffer1.getList(1));
        dataList.addAll(buffer2.getList(1));
        buffer1.update(0, longList);
        buffer1.update(1, dataList);
    }

    @Override
    public boolean deterministic() {
        return true;
    }
}
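The crux of the question is that evaluate() must return one value whose fields line up, by position, with the StructType declared in dataType(); returning dataList alone matches neither field. The sketch below is a Spark-free model of the buffer logic so it can run without a cluster: the class TwoArrayAggregationModel and its Buffer type are hypothetical stand-ins for CustomUDAF's buffer, not Spark API. Its evaluate() packs both arrays into a single Object[] in dataType() field order. In the real UDAF those two values would instead be wrapped in one Row (for example via RowFactory.create(longs, data)) or a scala.Tuple2. That Spark's converters accept Java arrays or java.util.List values for ArrayType fields is an assumption here, not something checked against a specific Spark version.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical, Spark-free model of CustomUDAF's aggregation lifecycle.
public class TwoArrayAggregationModel {

    // Mirrors bufferSchema(): slot 0 = longArray, slot 1 = dataArray.
    static final class Buffer {
        final List<Long> longs = new ArrayList<>();
        final List<Double> data = new ArrayList<>();
    }

    // Mirrors update(): append one input row (long_col, double_col).
    static void update(Buffer buffer, long longValue, double doubleValue) {
        buffer.longs.add(longValue);
        buffer.data.add(doubleValue);
    }

    // Mirrors merge(): concatenate a second partial buffer onto the first.
    static void merge(Buffer buffer1, Buffer buffer2) {
        buffer1.longs.addAll(buffer2.longs);
        buffer1.data.addAll(buffer2.data);
    }

    // Mirrors evaluate(): return ONE value holding both arrays, in the same
    // field order as dataType(). In the real UDAF this pair would be wrapped
    // in a Row (e.g. RowFactory.create(longs, data)) or a scala.Tuple2.
    static Object[] evaluate(Buffer buffer) {
        long[] longs = buffer.longs.stream().mapToLong(Long::longValue).toArray();
        double[] data = buffer.data.stream().mapToDouble(Double::doubleValue).toArray();
        return new Object[] {longs, data};
    }

    public static void main(String[] args) {
        // Two "partitions" are aggregated separately, then merged.
        Buffer partition1 = new Buffer();
        update(partition1, 1L, 1.0);
        update(partition1, 2L, 2.0);
        Buffer partition2 = new Buffer();
        update(partition2, 3L, 3.0);
        merge(partition1, partition2);

        Object[] result = evaluate(partition1);
        System.out.println(Arrays.toString((long[]) result[0]));   // [1, 2, 3]
        System.out.println(Arrays.toString((double[]) result[1])); // [1.0, 2.0, 3.0]
    }
}
```

The design point is that the caller receives both arrays as a single struct-shaped result, so each field is then read back by position (or by the names longArray and dataArray) rather than as separate return values.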
UPDATE: Based on zero323's answer, I was able to return both arrays with:
return new Tuple2<>(longArray, dataArray);
Getting the data back out is a bit more involved: it requires deconstructing the DataFrame into Java Lists and then building it back into a DataFrame.
Answer:
As far as I can tell, returning a tuple should be enough. In Scala:
import org.apache.spark.sql.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.{Row, Column}

object DummyUDAF extends UserDefinedAggregateFunction {
  def inputSchema = new StructType().add("x", StringType)
  def bufferSchema = new StructType()
    .add("buff", ArrayType(LongType))
    .add("buff2", ArrayType(DoubleType))
  def dataType = new StructType()
    .add("xs", ArrayType(LongType))
    .add("ys", ArrayType(DoubleType))
  def deterministic = true
  def initialize(buffer: MutableAggregationBuffer) = {}
  def update(buffer: MutableAggregationBuffer, input: Row) = {}
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {}
  def evaluate(buffer: Row) = (Array(1L, 2L, 3L), Array(1.0, 2.0, 3.0))
}

val df = sc.parallelize(Seq(("a", 1), ("b", 2))).toDF("k", "v")
df.select(DummyUDAF($"k")).show(1, false)
// +---------------------------------------------------+
// |(DummyUDAF$(k),mode=Complete,isDistinct=false) |
// +---------------------------------------------------+
// |[WrappedArray(1, 2, 3),WrappedArray(1.0, 2.0, 3.0)]|
// +---------------------------------------------------+