有人可以告诉我 Spark 1.5 中与 collect_set 等效的函数吗?
/**
 * User-defined aggregate function that gathers the values of a single input
 * column into an array of the same element type.
 *
 * NOTE(review): in this excerpt the closing braces of the method bodies and of
 * the class, as well as the body of `evaluate`, are not present; the comments
 * below describe only the statements that are visible.
 *
 * @tparam T       Scala type used to read and store the column values
 * @param colType  Spark SQL data type of the aggregated column
 */
class CollectSetFunction[T](val colType: DataType) extends UserDefinedAggregateFunction {
// Input: one column named "inputCol" of type `colType`.
def inputSchema: StructType =
new StructType().add("inputCol", colType)
// Aggregation buffer: one slot named "outputCol" declared as an array of `colType`.
def bufferSchema: StructType =
new StructType().add("outputCol", ArrayType(colType))
// Result type: an array of `colType`.
def dataType: DataType = ArrayType(colType)
// Declared deterministic.
def deterministic: Boolean = true
// Sets buffer slot 0 to an empty ArrayBuffer[T].
def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer.update(0, new scala.collection.mutable.ArrayBuffer[T])
// Appends the input value at index 0 to the sequence held in slot 0; null inputs are skipped.
// No de-duplication happens here: the value is appended unconditionally.
def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
val list = buffer.getSeq[T](0)
if (!input.isNullAt(0)) {
val sales = input.getAs[T](0)
buffer.update(0, list:+sales)
// Converts both buffers' sequences to sets, unions them, and writes the result to slot 0 of buffer1.
// This is the only visible place where duplicates are removed.
// NOTE(review): the value written is a Set while the slot is declared as ArrayType(colType) --
// confirm the buffer accepts it, or convert with .toSeq.
def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1.update(0, buffer1.getSeq[T](0).toSet ++ buffer2.getSeq[T](0).toSet)
// Produces the final aggregate value from the buffer.
// NOTE(review): the body is missing from this excerpt -- presumably it returns the contents of
// slot 0; verify against the original answer, and whether it de-duplicates when merge was not called.
def evaluate(buffer: Row): Any = {
答案 0 :(得分:2)
该 API 有更简单的替代方案。使用 RDD 会非常简单;在 1.5 中最好改用 RDD API 来实现,因为 DataFrame 尚未完全实现。
scala> val rdd = sc.parallelize((1 to 10)).map(x => (x%5,x))
scala> rdd.groupByKey.mapValues(_.toSet.toList).toDF("k","set").show
| k| set|
| 0|[5, 10]|
| 1| [1, 6]|
| 2| [2, 7]|
| 3| [3, 8]|
| 4| [4, 9]|
def collectSet(df: DataFrame, k: Column, v: Column) = df
.map( r => (r.getInt(0),r.getInt(1)))
scala> val df = sc.parallelize((1 to 10)).toDF("v").withColumn("k", pmod('v,lit(5)))
df: org.apache.spark.sql.DataFrame = [v: int, k: int]
scala> val csudaf = new CollectSetFunction[Int](IntegerType)
scala> df.groupBy('k).agg(collect_set('v),csudaf('v)).show
| k|collect_set(v)|CollectSetFunction(v)|
| 0| [5, 10]| [5, 10]|
| 1| [1, 6]| [1, 6]|
| 2| [2, 7]| [2, 7]|
| 3| [3, 8]| [3, 8]|
| 4| [4, 9]| [4, 9]|
scala> val df = sc.parallelize((1 to 100000)).toDF("v").withColumn("k", floor(rand*10))
df: org.apache.spark.sql.DataFrame = [v: int, k: bigint]
scala> df.groupBy('k).agg(collect_set('v).as("a"),csudaf('v).as("b"))
|(a = b)|count|
| true| 10|