I have some code like this:
// in a spark-shell session; otherwise: import spark.implicits._
case class ns(name : String, age : Integer, grp : Integer)
val df = List(
ns("sara", 45, 1),
ns("bob", 45, 1),
ns("lucy", 5, 2),
ns("paul", 5, 3),
ns("cindyluhoo", 1, 3),
ns("simona", 45, 2)
).toDF
df.show
//+----------+---+---+
//| name|age|grp|
//+----------+---+---+
//| sara| 45| 1|
//| bob| 45| 1|
//| lucy| 5| 2|
//| paul| 5| 3|
//|cindyluhoo| 1| 3|
//| simona| 45| 2|
//+----------+---+---+
df.orderBy($"age", $"grp").show
//+----------+---+---+
//| name|age|grp|
//+----------+---+---+
//|cindyluhoo| 1| 3|
//| lucy| 5| 2|
//| paul| 5| 3|
//| bob| 45| 1|
//| sara| 45| 1|
//| simona| 45| 2|
//+----------+---+---+
df.orderBy($"grp").groupBy($"age").agg( collect_set($"name").as('lst)).as('fs).show
//+---+-------------------+
//|age| lst|
//+---+-------------------+
//| 45|[sara, bob, simona]|
//| 1| [cindyluhoo]|
//| 5| [lucy, paul]|
//+---+-------------------+
I want this (the order of the ages doesn't matter, only the grouping within each age — see the output):
//+----------+---+---+---------------------+
//| name|age|grp| windowedData |
//+----------+---+---+---------------------+
//|cindyluhoo| 1| 3| [cindyluhoo] |
//| lucy| 5| 2| [lucy] |
//| paul| 5| 3| [lucy, paul] |
//| bob| 45| 1| [bob] |
//| sara| 45| 1| [bob, sara] |
//| simona| 45| 2| [bob, sara, simona] |
//+----------+---+---+---------------------+
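For reference, a running-frame window along these lines might produce that shape (a sketch I have not verified, assuming a Spark version where `collect_list` works over windows and `Window.unboundedPreceding` exists, i.e. roughly 2.1+):

```scala
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.collect_list

// Frame runs from the first row of the age partition up to the current row,
// so each row sees the names accumulated so far in grp order.
val runningSpec = Window
  .partitionBy($"age")
  .orderBy($"grp")
  .rowsBetween(Window.unboundedPreceding, Window.currentRow)

df.withColumn("windowedData", collect_list($"name").over(runningSpec)).show(false)
```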
I know I need a window, so I tested this:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.row_number

val windowSpec = Window.partitionBy($"age").orderBy($"grp")
val indexCol = row_number.over(windowSpec)
df.withColumn("ind", indexCol).show
//+----------+---+---+---+
//| name|age|grp|ind|
//+----------+---+---+---+
//| sara| 45| 1| 1|
//| bob| 45| 1| 2|
//| simona| 45| 2| 3|
//|cindyluhoo| 1| 3| 1|
//| lucy| 5| 2| 1|
//| paul| 5| 3| 2|
//+----------+---+---+---+
Now I'll try to build up the WrappedArray:
Edit: fixed a typo in the following line:
val accNamesCol = lag($"name",0).over(windowSpec)
df.withColumn("ind", indexCol).withColumn("accNames", accNamesCol).show
//+----------+---+---+---+----------+
//| name|age|grp|ind| accNames|
//+----------+---+---+---+----------+
//| sara| 45| 1| 1| sara|
//| bob| 45| 1| 2| bob|
//| simona| 45| 2| 3| simona|
//|cindyluhoo| 1| 3| 1|cindyluhoo|
//| lucy| 5| 2| 1| lucy|
//| paul| 5| 3| 2| paul|
//+----------+---+---+---+----------+
That's not what I want, so I tried this:
df.withColumn("ind", indexCol).withColumn("accNames", collect_set(accNamesCol)).show
Everything up to this point works, but the last line crashes with:
java.lang.StackOverflowError
at java.lang.ThreadLocal.get(ThreadLocal.java:161)
at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.get(TreeNode.scala:57)
How can I do this with window functions? (I'd also like to keep the ind column.)
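One variant I still want to test: deriving ind and the accumulated names from the same partitioning. Note that collect_set would drop duplicate names, so collect_list may be the closer fit; the explicit ROWS frame matters because the default frame under orderBy is RANGE-based and would add tied grp rows all at once. This is an untested sketch assuming Spark 2.1+:

```scala
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{collect_list, row_number}

val windowSpec = Window.partitionBy($"age").orderBy($"grp")
// Explicit ROWS frame: rows with the same grp accumulate one at a time
// instead of together (as the default RANGE frame would do).
val runningSpec = windowSpec.rowsBetween(Window.unboundedPreceding, Window.currentRow)

df.withColumn("ind", row_number.over(windowSpec))
  .withColumn("accNames", collect_list($"name").over(runningSpec))
  .show(false)
```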
Edit #1 (added section):
To clarify, I'm looking for a way to accumulate the name column. The final output should look like this:
//+----------+---+---+---+---------------------+
//| name|age|grp|ind| accNames |
//+----------+---+---+---+---------------------+
//| sara| 45| 1| 1| [sara] |
//| bob| 45| 1| 2| [sara, bob] |
//| simona| 45| 2| 3| [sara, bob, simona] |
//|cindyluhoo| 1| 3| 1| [cindyluhoo] |
//| lucy| 5| 2| 1| [lucy] |
//| paul| 5| 3| 2| [lucy, paul] |
//+----------+---+---+---+---------------------+
Edit #2: the final accNames column should have type scala.collection.mutable.WrappedArray[String].
More information (another edit):
I'm not sure whether the goal can be achieved with collect_set or whether it requires a UDAF, so I wrote this (tested):
case class ns(name : String, age : Integer, grp : Integer)
// examples
// https://issues.apache.org/jira/browse/SPARK-11372
// https://issues.apache.org/jira/browse/SPARK-11885
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class AggStringSet extends UserDefinedAggregateFunction {
  // inputSchema returns a StructType; every field of this StructType represents an input argument of this UDAF.
  override def inputSchema: StructType = StructType(StructField("value", StringType) :: Nil)
  // bufferSchema returns a StructType; every field of this StructType represents a field of this UDAF's intermediate results.
  override def bufferSchema: StructType = StructType(Seq(StructField("merge", ArrayType(StringType, true), true)))
  // dataType returns a DataType representing the data type of this UDAF's returned value.
  override def dataType: DataType = ArrayType(StringType, true)
  // indicates whether this UDAF always generates the same result for a given set of input values.
  override def deterministic: Boolean = true
  // used to initialize values of an aggregation buffer, represented by a MutableAggregationBuffer.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Seq[String]()
  }
  // Iterate over each entry of a group:
  // update an aggregation buffer represented by a MutableAggregationBuffer for an input Row.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // println("update1:" + input.getAs[String](0))
    buffer(0) = buffer.getAs[Seq[String]](0) :+ input.getAs[String](0)
    // println("update2:" + buffer(0))
  }
  // merge two aggregation buffers and store the result into a MutableAggregationBuffer.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // println("merge1: " + buffer1.getAs[Seq[String]](0))
    // println("merge2: " + buffer2.getAs[Seq[String]](0))
    buffer1(0) = buffer1.getAs[Seq[String]](0) ++ buffer2.getAs[Seq[String]](0)
  }
  // generate the final result value of this UDAF based on values stored in an aggregation buffer represented by a Row.
  override def evaluate(buffer: Row): Any = {
    // println("eval: "+ buffer.getAs[Seq[String]](0))
    buffer.getAs[Seq[String]](0)
  }
}
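If the UDAF route is the right one, I'd expect to apply it over the window roughly like this (an untested sketch; the instance name aggStringSet is mine, and I'm assuming UDAFs can be used with over() the same way built-in aggregates can):

```scala
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.row_number

val aggStringSet = new AggStringSet

val windowSpec = Window.partitionBy($"age").orderBy($"grp")
// Running ROWS frame so the UDAF buffer grows one row at a time.
val runningSpec = windowSpec.rowsBetween(Window.unboundedPreceding, Window.currentRow)

df.withColumn("ind", row_number.over(windowSpec))
  .withColumn("accNames", aggStringSet($"name").over(runningSpec))
  .show(false)
```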