我想将数组作为输入架构传递给UDAF。
我给出的示例非常简单,它仅对2个向量求和。实际上,我的用例更加复杂,我需要使用UDAF。
import sc.implicits._
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.expressions._
val df = Seq(
(1, Array(10.2, 12.3, 11.2)),
(1, Array(11.2, 12.6, 10.8)),
(2, Array(12.1, 11.2, 10.1)),
(2, Array(10.1, 16.0, 9.3))
).toDF("siteId", "bidRevenue")
class BidAggregatorBySiteId() extends UserDefinedAggregateFunction {
def inputSchema: StructType = StructType(Array(StructField("bidRevenue", ArrayType(DoubleType))))
def bufferSchema = StructType(Array(StructField("sumArray", ArrayType(DoubleType))))
def dataType: DataType = ArrayType(DoubleType)
def deterministic = true
def initialize(buffer: MutableAggregationBuffer) = {
buffer.update(0, Array(0.0, 0.0, 0.0))
}
def update(buffer: MutableAggregationBuffer, input: Row) = {
val seqBuffer = buffer(0).asInstanceOf[IndexedSeq[Double]]
val seqInput = input(0).asInstanceOf[IndexedSeq[Double]]
buffer(0) = seqBuffer.zip(seqInput).map{ case (x, y) => x + y }
}
def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
val seqBuffer1 = buffer1(0).asInstanceOf[IndexedSeq[Double]]
val seqBuffer2 = buffer2(0).asInstanceOf[IndexedSeq[Double]]
buffer1(0) = seqBuffer1.zip(seqBuffer2).map{ case (x, y) => x + y }
}
def evaluate(buffer: Row) = {
buffer
}
}
val fun = new BidAggregatorBySiteId()
df.select($"siteId", $"bidRevenue" cast(ArrayType(DoubleType)))
.groupBy("siteId").agg(fun($"bidRevenue"))
.show
对于“显示”操作之前的转换,所有方法都可以正常工作。但是节目引发了错误:
scala.MatchError:[WrappedArray(21.4,24.9,22.0)](属于org.apache.spark.sql.execution.aggregate.InputAggregationBuffer类) 在org.apache.spark.sql.catalyst.CatalystTypeConverters $ ArrayConverter.toCatalystImpl(CatalystTypeConverters.scala:160)
我的数据框的结构是:
root
|-- siteId: integer (nullable = false)
|-- bidRevenue: array (nullable = true)
| |-- element: double (containsNull = true)
df.dtypes = Array [(String,String)] = Array((“ siteId”,“ IntegerType”),(“ bidRevenue”,“ ArrayType(DoubleType,true)”))
为您提供宝贵的帮助。
答案 0 :(得分:0)
def evaluate(buffer: Row): Any
一旦完全处理了一个组以获得最终结果,就会调用上述方法。 在初始化和更新缓冲区的第0个索引时
i.e. buffer(0)
因此,由于汇总结果存储在0索引中,因此您需要在最后返回第0个索引值。
def evaluate(buffer: Row) = {
buffer.get(0)
}
对评估方法()进行以上修改将导致:
// +------+---------------------------------+
// |siteId|bidaggregatorbysiteid(bidRevenue)|
// +------+---------------------------------+
// | 1| [21.4, 24.9, 22.0]|
// | 2| [22.2, 27.2, 19.4]|
// +------+---------------------------------+