Verifying a collect_set-equivalent UDAF for Spark 1.5

Asked: 2016-10-12 05:05:34

Tags: scala apache-spark apache-spark-sql user-defined-functions

Can someone tell me the equivalent of collect_set in Spark 1.5?

Is there any workaround to get a result similar to collect_set(col(name))?

Is the following the correct approach:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class CollectSetFunction[T](val colType: DataType) extends UserDefinedAggregateFunction {

  def inputSchema: StructType =
    new StructType().add("inputCol", colType)

  def bufferSchema: StructType =
    new StructType().add("outputCol", ArrayType(colType))

  def dataType: DataType = ArrayType(colType)

  def deterministic: Boolean = true

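  // Each group's buffer starts as an empty collection.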
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, new scala.collection.mutable.ArrayBuffer[T])
  }

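  // Append each non-null input value; duplicates may accumulate within a
  // partition and are removed only when partial buffers are merged.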
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val list = buffer.getSeq[T](0)
    if (!input.isNullAt(0)) {
      val value = input.getAs[T](0)
      buffer.update(0, list :+ value)
    }
  }

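  // Merge two partial buffers, deduplicating via a set union.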
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, (buffer1.getSeq[T](0).toSet ++ buffer2.getSeq[T](0).toSet).toSeq)
  }

  def evaluate(buffer: Row): Any = {
    buffer.getSeq[T](0)
  }
}

1 Answer:

Answer 0 (score: 2):

The code looks correct. I also tested it in local mode on 1.6.2 and got the same result (see below). I don't know of a simpler alternative using the DataFrame API. Using RDDs it is fairly straightforward, and it may be preferable to detour through the RDD API in 1.5, since DataFrames were not fully implemented there.

scala> val rdd = sc.parallelize(1 to 10).map(x => (x % 5, x))
scala> rdd.groupByKey.mapValues(_.toSet.toList).toDF("k", "set").show
+---+-------+
|  k|    set|
+---+-------+
|  0|[5, 10]|
|  1| [1, 6]|
|  2| [2, 7]|
|  3| [3, 8]|
|  4| [4, 9]|
+---+-------+

If you want to factor this out into a function, a first version (which could be improved) might look like the following:

def collectSet(df: DataFrame, k: Column, v: Column) = df
    .select(k.as("k"), v.as("v"))
    .map(r => (r.getInt(0), r.getInt(1)))
    .groupByKey()
    .mapValues(_.toSet.toList)
    .toDF("k", "v")

But if you want to do other aggregations as well, you will not be able to avoid a join.
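
For example (a hypothetical sketch reusing the collectSet helper and column names from above; sum stands in for whatever other aggregate is needed):

// Compute the collected sets and the other aggregates separately,
// then join the two results back together on the grouping key.
val sets = collectSet(df, 'k, 'v)
val sums = df.groupBy('k).agg(sum('v).as("total"))
val combined = sums.join(sets, "k")

For reference, here is the comparison of the UDAF against the built-in collect_set on 1.6.2: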

scala> val df = sc.parallelize((1 to 10)).toDF("v").withColumn("k", pmod('v,lit(5)))
df: org.apache.spark.sql.DataFrame = [v: int, k: int]

scala> val csudaf = new CollectSetFunction[Int](IntegerType)

scala> df.groupBy('k).agg(collect_set('v),csudaf('v)).show
+---+--------------+---------------------+
|  k|collect_set(v)|CollectSetFunction(v)|
+---+--------------+---------------------+
|  0|       [5, 10]|              [5, 10]|
|  1|        [1, 6]|               [1, 6]|
|  2|        [2, 7]|               [2, 7]|
|  3|        [3, 8]|               [3, 8]|
|  4|        [4, 9]|               [4, 9]|
+---+--------------+---------------------+

Test 2:

scala> val df = sc.parallelize((1 to 100000)).toDF("v").withColumn("k", floor(rand*10))
df: org.apache.spark.sql.DataFrame = [v: int, k: bigint]

scala> df.groupBy('k).agg(collect_set('v).as("a"),csudaf('v).as("b"))
         .groupBy('a==='b).count.show
+-------+-----+                                                                 
|(a = b)|count|
+-------+-----+
|   true|   10|
+-------+-----+
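
One caveat: the element order inside the collected arrays is not guaranteed, so the direct 'a === 'b comparison above works only because both aggregations happened to produce the same ordering. A more robust check (a sketch; sort_array has been available since Spark 1.5) sorts both arrays before comparing:

scala> df.groupBy('k).agg(collect_set('v).as("a"), csudaf('v).as("b"))
         .groupBy(sort_array('a) === sort_array('b)).count.show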