我正在处理一个大型spark DataFrame中的一列数字,我想创建一个新列,该列存储该列中出现的唯一数字的聚合列表。
基本上就是functions.collect_set的作用。但是,我只需要聚合列表中最多1000个元素。有没有办法以某种方式将该参数传递给functions.collect_set(),或者在不使用UDAF的情况下以任何其他方式在聚合列表中仅获取最多1000个元素?
由于列太大,我想避免收集所有元素并在之后修剪列表。
谢谢!
答案 0 :(得分:5)
我的解决方案与Loki's answer with collect_set_limit
非常相似。
我使用的UDF会在collect_set
(或collect_list
)或更难的UDAF之后执行您想要的操作。
鉴于UDF的更多经验,我首先要考虑它。尽管UDF没有经过优化,但对于这个用例来说还不错。
val limitUDF = udf { (nums: Seq[Long], limit: Int) => nums.take(limit) }
val sample = spark.range(50).withColumn("key", $"id" % 5)
scala> sample.groupBy("key").agg(collect_set("id") as "all").show(false)
+---+--------------------------------------+
|key|all |
+---+--------------------------------------+
|0 |[0, 15, 30, 45, 5, 20, 35, 10, 25, 40]|
|1 |[1, 16, 31, 46, 6, 21, 36, 11, 26, 41]|
|3 |[33, 48, 13, 38, 3, 18, 28, 43, 8, 23]|
|2 |[12, 27, 37, 2, 17, 32, 42, 7, 22, 47]|
|4 |[9, 19, 34, 49, 24, 39, 4, 14, 29, 44]|
+---+--------------------------------------+
scala> sample.
groupBy("key").
agg(collect_set("id") as "all").
withColumn("limit(3)", limitUDF($"all", lit(3))).
show(false)
+---+--------------------------------------+------------+
|key|all |limit(3) |
+---+--------------------------------------+------------+
|0 |[0, 15, 30, 45, 5, 20, 35, 10, 25, 40]|[0, 15, 30] |
|1 |[1, 16, 31, 46, 6, 21, 36, 11, 26, 41]|[1, 16, 31] |
|3 |[33, 48, 13, 38, 3, 18, 28, 43, 8, 23]|[33, 48, 13]|
|2 |[12, 27, 37, 2, 17, 32, 42, 7, 22, 47]|[12, 27, 37]|
|4 |[9, 19, 34, 49, 24, 39, 4, 14, 29, 44]|[9, 19, 34] |
+---+--------------------------------------+------------+
请参阅functions对象(适用于udf
功能的文档)。
答案 1 :(得分:4)
我正在使用collect_set和collect_list函数的修改副本;由于代码范围,修改后的副本必须与原始文件位于相同的包路径中。链接代码适用于Spark 2.1.0;如果您使用的是先前版本,则方法签名可能会有所不同。
将此文件(https://gist.github.com/lokkju/06323e88746c85b2ce4de3ea9cdef9bc)作为src / main / org / apache / spark / sql / catalyst / expression / collect_limit.scala
投入项目将其用作:
import org.apache.spark.sql.catalyst.expression.collect_limit._
df.groupBy('set_col).agg(collect_set_limit('set_col,1000)
答案 2 :(得分:2)
scala> df.show
+---+-----+----+--------+
| C0| C1| C2| C3|
+---+-----+----+--------+
| 10| Name|2016| Country|
| 11|Name1|2016|country1|
| 10| Name|2016| Country|
| 10| Name|2016| Country|
| 12|Name2|2017|Country2|
+---+-----+----+--------+
scala> df.groupBy("C1").agg(sum("C0"))
res36: org.apache.spark.sql.DataFrame = [C1: string, sum(C0): bigint]
scala> res36.show
+-----+-------+
| C1|sum(C0)|
+-----+-------+
|Name1| 11|
|Name2| 12|
| Name| 30|
+-----+-------+
scala> df.limit(2).groupBy("C1").agg(sum("C0"))
res33: org.apache.spark.sql.DataFrame = [C1: string, sum(C0): bigint]
scala> res33.show
+-----+-------+
| C1|sum(C0)|
+-----+-------+
| Name| 10|
|Name1| 11|
+-----+-------+
scala> df.groupBy("C1").agg(sum("C0")).limit(2)
res2: org.apache.spark.sql.DataFrame = [C1: string, sum(C0): bigint]
scala> res2.show
+-----+-------+
| C1|sum(C0)|
+-----+-------+
|Name1| 11|
|Name2| 12|
+-----+-------+
scala> df.distinct
res8: org.apache.spark.sql.DataFrame = [C0: int, C1: string, C2: int, C3: string]
scala> res8.show
+---+-----+----+--------+
| C0| C1| C2| C3|
+---+-----+----+--------+
| 11|Name1|2016|country1|
| 10| Name|2016| Country|
| 12|Name2|2017|Country2|
+---+-----+----+--------+
scala> df.dropDuplicates(Array("c1"))
res11: org.apache.spark.sql.DataFrame = [C0: int, C1: string, C2: int, C3: string]
scala> res11.show
+---+-----+----+--------+
| C0| C1| C2| C3|
+---+-----+----+--------+
| 11|Name1|2016|country1|
| 12|Name2|2017|Country2|
| 10| Name|2016| Country|
+---+-----+----+--------+
答案 3 :(得分:1)
使用take
val firstThousand = rdd.take(1000)
将返回前1000。 Collect还具有可以提供的过滤功能。这将使您能够更具体地了解返回的内容。
答案 4 :(得分:1)
正如其他答案所提到的那样,执行此操作的高性能方法是编写UDAF。不幸的是,UDAF API实际上不像spark附带的聚合函数那样可扩展。但是,您可以使用它们的内部API来构建内部函数以完成所需的工作。
这里是collect_set_limit
的实现,大部分是Spark内部CollectSet
AggregateFunction的复制版本。我只是扩展它,但是它是一个案例类。真正需要做的就是重写update和merge方法以遵守传入的限制:
case class CollectSetLimit(
child: Expression,
limitExp: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0) extends Collect[mutable.HashSet[Any]] {
val limit = limitExp.eval( null ).asInstanceOf[Int]
def this(child: Expression, limit: Expression) = this(child, limit, 0, 0)
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
copy(inputAggBufferOffset = newInputAggBufferOffset)
override def createAggregationBuffer(): mutable.HashSet[Any] = mutable.HashSet.empty
override def update(buffer: mutable.HashSet[Any], input: InternalRow): mutable.HashSet[Any] = {
if( buffer.size < limit ) super.update(buffer, input)
else buffer
}
override def merge(buffer: mutable.HashSet[Any], other: mutable.HashSet[Any]): mutable.HashSet[Any] = {
if( buffer.size >= limit ) buffer
else buffer ++= other.take( limit - buffer.size )
}
override def prettyName: String = "collect_set_limit"
}
要进行实际注册,我们可以通过Spark内部的FunctionRegistry
(使用名称)和构建器来进行注册,该构建器实际上是使用提供的表达式创建CollectSetLimit
的函数:
val collectSetBuilder = (args: Seq[Expression]) => CollectSetLimit( args( 0 ), args( 1 ) )
FunctionRegistry.builtin.registerFunction( "collect_set_limit", collectSetBuilder )
编辑:
结果证明,只有在尚未创建SparkContext的情况下,才能将其添加到内置组件中,因为它会在启动时生成不可变的克隆。如果您已有一个上下文,那么应该使用反射将其添加:
val field = classOf[SessionCatalog].getFields.find( _.getName.endsWith( "functionRegistry" ) ).get
field.setAccessible( true )
val inUseRegistry = field.get( SparkSession.builder.getOrCreate.sessionState.catalog ).asInstanceOf[FunctionRegistry]
inUseRegistry.registerFunction( "collect_set_limit", collectSetBuilder )