有没有办法将限制参数传递给Spark中的functions.collect_set?

时间:2016-08-02 21:39:06

标签: apache-spark dataframe apache-spark-sql aggregate-functions

我正在处理一个大型spark DataFrame中的一列数字,我想创建一个新列,该列存储该列中出现的唯一数字的聚合列表。

基本上就是functions.collect_set的作用。但是,我只需要聚合列表中最多1000个元素。有没有办法以某种方式将该参数传递给functions.collect_set(),或者在不使用UDAF的情况下以任何其他方式在聚合列表中仅获取最多1000个元素?

由于列太大,我想避免收集所有元素并在之后修剪列表。

谢谢!

5 个答案:

答案 0 :(得分:5)

我的解决方案与Loki's answer with collect_set_limit非常相似。

我使用的UDF会在collect_set(或collect_list)或更难的UDAF之后执行您想要的操作。

鉴于UDF的更多经验,我首先要考虑它。尽管UDF没有经过优化,但对于这个用例来说还不错。

val limitUDF = udf { (nums: Seq[Long], limit: Int) => nums.take(limit) }
val sample = spark.range(50).withColumn("key", $"id" % 5)

scala> sample.groupBy("key").agg(collect_set("id") as "all").show(false)
+---+--------------------------------------+
|key|all                                   |
+---+--------------------------------------+
|0  |[0, 15, 30, 45, 5, 20, 35, 10, 25, 40]|
|1  |[1, 16, 31, 46, 6, 21, 36, 11, 26, 41]|
|3  |[33, 48, 13, 38, 3, 18, 28, 43, 8, 23]|
|2  |[12, 27, 37, 2, 17, 32, 42, 7, 22, 47]|
|4  |[9, 19, 34, 49, 24, 39, 4, 14, 29, 44]|
+---+--------------------------------------+

scala> sample.
  groupBy("key").
  agg(collect_set("id") as "all").
  withColumn("limit(3)", limitUDF($"all", lit(3))).
  show(false)
+---+--------------------------------------+------------+
|key|all                                   |limit(3)    |
+---+--------------------------------------+------------+
|0  |[0, 15, 30, 45, 5, 20, 35, 10, 25, 40]|[0, 15, 30] |
|1  |[1, 16, 31, 46, 6, 21, 36, 11, 26, 41]|[1, 16, 31] |
|3  |[33, 48, 13, 38, 3, 18, 28, 43, 8, 23]|[33, 48, 13]|
|2  |[12, 27, 37, 2, 17, 32, 42, 7, 22, 47]|[12, 27, 37]|
|4  |[9, 19, 34, 49, 24, 39, 4, 14, 29, 44]|[9, 19, 34] |
+---+--------------------------------------+------------+

请参阅functions对象(适用于udf功能的文档)。

答案 1 :(得分:4)

我正在使用collect_set和collect_list函数的修改副本;由于代码范围,修改后的副本必须与原始文件位于相同的包路径中。链接代码适用于Spark 2.1.0;如果您使用的是先前版本,则方法签名可能会有所不同。

将此文件(https://gist.github.com/lokkju/06323e88746c85b2ce4de3ea9cdef9bc)作为src / main / org / apache / spark / sql / catalyst / expression / collect_limit.scala

投入项目

将其用作:

import org.apache.spark.sql.catalyst.expression.collect_limit._
df.groupBy('set_col).agg(collect_set_limit('set_col,1000)

答案 2 :(得分:2)

 scala> df.show
    +---+-----+----+--------+
    | C0|   C1|  C2|      C3|
    +---+-----+----+--------+
    | 10| Name|2016| Country|
    | 11|Name1|2016|country1|
    | 10| Name|2016| Country|
    | 10| Name|2016| Country|
    | 12|Name2|2017|Country2|
    +---+-----+----+--------+

scala> df.groupBy("C1").agg(sum("C0"))
res36: org.apache.spark.sql.DataFrame = [C1: string, sum(C0): bigint]

scala> res36.show
+-----+-------+
|   C1|sum(C0)|
+-----+-------+
|Name1|     11|
|Name2|     12|
| Name|     30|
+-----+-------+

scala> df.limit(2).groupBy("C1").agg(sum("C0"))
    res33: org.apache.spark.sql.DataFrame = [C1: string, sum(C0): bigint]

    scala> res33.show
    +-----+-------+
    |   C1|sum(C0)|
    +-----+-------+
    | Name|     10|
    |Name1|     11|
    +-----+-------+



    scala> df.groupBy("C1").agg(sum("C0")).limit(2)
res2: org.apache.spark.sql.DataFrame = [C1: string, sum(C0): bigint]

scala> res2.show
+-----+-------+
|   C1|sum(C0)|
+-----+-------+
|Name1|     11|
|Name2|     12|
+-----+-------+

scala> df.distinct
res8: org.apache.spark.sql.DataFrame = [C0: int, C1: string, C2: int, C3: string]

scala> res8.show
+---+-----+----+--------+
| C0|   C1|  C2|      C3|
+---+-----+----+--------+
| 11|Name1|2016|country1|
| 10| Name|2016| Country|
| 12|Name2|2017|Country2|
+---+-----+----+--------+

scala> df.dropDuplicates(Array("c1"))
res11: org.apache.spark.sql.DataFrame = [C0: int, C1: string, C2: int, C3: string]

scala> res11.show
+---+-----+----+--------+                                                       
| C0|   C1|  C2|      C3|
+---+-----+----+--------+
| 11|Name1|2016|country1|
| 12|Name2|2017|Country2|
| 10| Name|2016| Country|
+---+-----+----+--------+

答案 3 :(得分:1)

使用take

val firstThousand = rdd.take(1000)

将返回前1000。 Collect还具有可以提供的过滤功能。这将使您能够更具体地了解返回的内容。

答案 4 :(得分:1)

正如其他答案所提到的那样,执行此操作的高性能方法是编写UDAF。不幸的是,UDAF API实际上不像spark附带的聚合函数那样可扩展。但是,您可以使用它们的内部API来构建内部函数以完成所需的工作。

这里是collect_set_limit的实现,大部分是Spark内部CollectSet AggregateFunction的复制版本。我只是扩展它,但是它是一个案例类。真正需要做的就是重写update和merge方法以遵守传入的限制:

case class CollectSetLimit(
    child: Expression,
    limitExp: Expression,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0) extends Collect[mutable.HashSet[Any]] {

  val limit = limitExp.eval( null ).asInstanceOf[Int]

  def this(child: Expression, limit: Expression) = this(child, limit, 0, 0)

  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  override def createAggregationBuffer(): mutable.HashSet[Any] = mutable.HashSet.empty

  override def update(buffer: mutable.HashSet[Any], input: InternalRow): mutable.HashSet[Any] = {
    if( buffer.size < limit ) super.update(buffer, input)
    else buffer
  }

  override def merge(buffer: mutable.HashSet[Any], other: mutable.HashSet[Any]): mutable.HashSet[Any] = {
    if( buffer.size >= limit ) buffer
    else buffer ++= other.take( limit - buffer.size )
  }

  override def prettyName: String = "collect_set_limit"
}

要进行实际注册,我们可以通过Spark内部的FunctionRegistry(使用名称)和构建器来进行注册,该构建器实际上是使用提供的表达式创建CollectSetLimit的函数:

val collectSetBuilder = (args: Seq[Expression]) => CollectSetLimit( args( 0 ), args( 1 ) )
FunctionRegistry.builtin.registerFunction( "collect_set_limit", collectSetBuilder )

编辑:

结果证明,只有在尚未创建SparkContext的情况下,才能将其添加到内置组件中,因为它会在启动时生成不可变的克隆。如果您已有一个上下文,那么应该使用反射将其添加:

val field = classOf[SessionCatalog].getFields.find( _.getName.endsWith( "functionRegistry" ) ).get
field.setAccessible( true )
val inUseRegistry = field.get( SparkSession.builder.getOrCreate.sessionState.catalog ).asInstanceOf[FunctionRegistry]
inUseRegistry.registerFunction( "collect_set_limit", collectSetBuilder )