Can someone tell me the equivalent of collect_set in Spark 1.5?
Is there any workaround to get a result similar to collect_set(col(name))?
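For reference, this is roughly the call I'd like to emulate (collect_set only exists in the DataFrame API from Spark 1.6 on; df and the column names here are just placeholders):

import org.apache.spark.sql.functions.{col, collect_set}

// Spark 1.6+ only: gather the distinct values of "name" for each key "k"
df.groupBy(col("k")).agg(collect_set(col("name"))).show()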
Is this the correct approach:
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class CollectSetFunction[T](val colType: DataType) extends UserDefinedAggregateFunction {

  // a single input column of the element type
  def inputSchema: StructType =
    new StructType().add("inputCol", colType)

  // the intermediate buffer holds the values collected so far
  def bufferSchema: StructType =
    new StructType().add("outputCol", ArrayType(colType))

  def dataType: DataType = ArrayType(colType)

  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, new scala.collection.mutable.ArrayBuffer[T])
  }

  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val list = buffer.getSeq[T](0)
    if (!input.isNullAt(0)) {
      val value = input.getAs[T](0)
      buffer.update(0, list :+ value)
    }
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // store a Seq (matching the declared ArrayType buffer) with duplicates removed
    buffer1.update(0, (buffer1.getSeq[T](0).toSet ++ buffer2.getSeq[T](0).toSet).toSeq)
  }

  def evaluate(buffer: Row): Any = {
    // deduplicate once more in case update saw repeated values within a partition
    buffer.getSeq[T](0).distinct
  }
}
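If the definition is right, I'd expect to use it roughly like this (a sketch; df and the columns "k"/"v" are placeholders):

import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.IntegerType

// instantiate the UDAF for Int values and apply it per group
val collectSetInt = new CollectSetFunction[Int](IntegerType)
df.groupBy(col("k")).agg(collectSetInt(col("v"))).show()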
Answer 0 (score: 2)
The code looks correct. Moreover, I tested it in 1.6.2 in local mode and got the same result (see below). I don't know of any simpler alternative using the DataFrame API. With the RDD API it is quite straightforward, and it may be preferable to drop down to RDDs in 1.5, since DataFrames are not fully implemented there.
scala> val rdd = sc.parallelize((1 to 10)).map(x => (x % 5, x))
scala> rdd.groupByKey.mapValues(_.toSet.toList).toDF("k", "set").show
+---+-------+
| k| set|
+---+-------+
| 0|[5, 10]|
| 1| [1, 6]|
| 2| [2, 7]|
| 3| [3, 8]|
| 4| [4, 9]|
+---+-------+
If you want to factor it out, an initial version (which could be improved) might look like the following:
def collectSet(df: DataFrame, k: Column, v: Column) = df
  .select(k.as("k"), v.as("v"))
  .map(r => (r.getInt(0), r.getInt(1)))  // DataFrame.map returns an RDD in Spark 1.x
  .groupByKey()
  .mapValues(_.toSet.toList)
  .toDF("k", "v")                        // needs sqlContext.implicits._ in scope
But if you want to do other aggregations as well, you will not be able to avoid a join.
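For instance, here is a sketch of combining the collected sets with a second aggregation (sum is just an example; "k"/"v" are the columns the helper produces):

import org.apache.spark.sql.functions.{col, sum}

// collected sets per key, computed through the RDD-based helper above
val sets = collectSet(df, col("k"), col("v"))

// an ordinary DataFrame aggregation on the same key
val totals = df.groupBy(col("k")).agg(sum(col("v")).as("total"))

// the two results have to be joined back together on the key
val combined = sets.join(totals, "k")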
Test 1:
scala> import org.apache.spark.sql.functions._
scala> import org.apache.spark.sql.types._
scala> val df = sc.parallelize((1 to 10)).toDF("v").withColumn("k", pmod('v, lit(5)))
df: org.apache.spark.sql.DataFrame = [v: int, k: int]
scala> val csudaf = new CollectSetFunction[Int](IntegerType)
scala> df.groupBy('k).agg(collect_set('v),csudaf('v)).show
+---+--------------+---------------------+
| k|collect_set(v)|CollectSetFunction(v)|
+---+--------------+---------------------+
| 0| [5, 10]| [5, 10]|
| 1| [1, 6]| [1, 6]|
| 2| [2, 7]| [2, 7]|
| 3| [3, 8]| [3, 8]|
| 4| [4, 9]| [4, 9]|
+---+--------------+---------------------+
Test 2:
scala> val df = sc.parallelize((1 to 100000)).toDF("v").withColumn("k", floor(rand*10))
df: org.apache.spark.sql.DataFrame = [v: int, k: bigint]
scala> df.groupBy('k).agg(collect_set('v).as("a"), csudaf('v).as("b")).groupBy('a === 'b).count.show
+-------+-----+
|(a = b)|count|
+-------+-----+
| true| 10|
+-------+-----+