我从我的mongodb获得了一些看起来像这样的数据:
+------+-------+
| view | data |
+------+-------+
| xx | *** |
| yy | *** |
| xx | *** |
+------+-------+
没有必要知道里面是什么。
我写了一个像这样的UserDefinedAggregateFunction,因为我想在视图上分组。:
class Extractor() extends UserDefinedAggregateFunction{
override def inputSchema: StructType = // some stuff
override def bufferSchema: StructType =
StructType(
List(
StructField("0",IntegerType,false),
StructField("1",IntegerType,false),
StructField("2",IntegerType,false),
StructField("3",IntegerType,false),
StructField("4",IntegerType,false),
StructField("5",IntegerType,false),
StructField("6",IntegerType,false),
StructField("7",IntegerType,false)
)
)
override def dataType: DataType = bufferSchema
override def deterministic: Boolean = true
override def initialize(buffer: MutableAggregationBuffer): Unit = {
for (x <- 0 to 7){
buffer(x) = 0
}
}
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = // some stuff
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = // some stuff
override def evaluate(buffer: Row): Any =
var l = List.empty[Integer]
for (x <- 7 to 0 by -1){
l = buffer.getInt(x) :: l
}
l
}
我的输出应该是这样的:
+------+---+---+---+---+---+---+---+---+
| view | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
+------+---+---+---+---+---+---+---+---+
| xx | 0 | 0 | 4 | 1 | 0 | 0 | 3 | 0 |
| yy | 0 | 0 | 0 | 3 | 0 | 1 | 0 | 0 |
+------+---+---+---+---+---+---+---+---+
这些值是在上面的更新/合并功能中计算出来的,但是没有必要让你看到它。
然后我像这样使用它:
val ex = new Extractor()
val df = dataset.groupBy("view").agg(
ex(dataset.col("data"))
)
df.show()
当我执行df.show()时,它总是给我一个IndexOutOfBoundException。我知道这是懒惰的评估,这就是我在df.show()中出错的原因。
据我所知,它可以执行第一组并结束evaluate函数。但之后我得到一个IndexOutOfBoundException ......
当我更改dataType并将Function评估为:
时override def dataType: DataType =
ArrayType(IntegerType,false)
override def evaluate(buffer: Row): Any = {
var l = ofDim[Integer](8)
for (x <- 0 to 7){
l(x) = buffer.getInt(x)
}
l
输出如下:
+------+------------------------------+
| view | Extractor |
+------+------------------------------+
| xx | [0, 0, 4, 1, 0, 0, 3, 0] |
| yy | [0, 0, 0, 3, 0, 1, 0, 0] |
+------+------------------------------+
架构看起来像这样:
root
|-- view: string (nullable = true)
|-- Extractor: array (nullable = true)
| |-- element: integer (containsNull = false)
我无法以我想要的形式转换它。
因为第二种方法有效,我认为我在第一种方法中弄乱了DataType,但我不知道如何修复它......
许多介绍我的问题:
如何获得我想要的输出? 我并不关心这两种方法中的哪一种(首先是多输出列或者可以转换为我想要的形式的数组),只要它有效。
感谢您的帮助
答案 0 :(得分:0)
您将聚合输出定义为列表:
override def dataType: DataType = bufferSchema
因为bufferSchema
是一个列表,所以这就是你最终得到的。您可以稍后更改架构,并将列表中的每个列转换为新列。
对于您的错误,区别:
override def evaluate(buffer: Row): Any =
var l = List.empty[Integer]
for (x <- 7 to 0 by -1){
l = buffer.getInt(x) :: l
}
l
和
override def evaluate(buffer: Row): Any =
var l = ofDim[Integer](8)
for (x <- 0 to 7){
l = buffer.getInt(x) :: l
}
l
是在第二个中,您定义了预定义数量的列。因此,您确定可以毫无问题地从0到7进行迭代。
您的第一个示例并非如此,因此,我怀疑您可能有错误格式的数据会导致您的缓冲区在initialize
或merge
中被错误地初始化。我建议你添加一个try / catch来验证缩放缓冲区长度的每个步骤之后的大小(至少initialize
,但也可以是update
或merge
。< / p>
要为列表中的每个元素添加列,您可以使用withColumn
或通过地图执行此操作。