Spark collect_list and limit the resulting list

Date: 2018-09-23 15:25:09

Tags: scala apache-spark dataframe limit

I have a DataFrame in the following format:

name          merged
key1    (internalKey1, value1)
key1    (internalKey2, value2)
...
key2    (internalKey3, value3)
...

What I want to do is group the DataFrame by name, collect the merged values into a list, and limit the size of that list.

This is how I group by name and collect the list:

val res = df.groupBy("name")
            .agg(collect_list(col("merged")).as("final"))

The resulting DataFrame looks something like this:

 key1   [(internalKey1, value1), (internalKey2, value2),...] // Limit the size of this list 
 key2   [(internalKey3, value3),...]

What I want is to limit the size of the list produced for each key. I have tried several ways to do this without success. I have also seen posts suggesting third-party solutions, but I would like to avoid that. Is there a way to do this?

3 Answers:

Answer 0: (score 3)

You can create a function that limits the size of the aggregated ArrayType column, as shown below:

import org.apache.spark.sql.functions._
import org.apache.spark.sql.Column

case class KV(k: String, v: String)

val df = Seq(
  ("key1", KV("internalKey1", "value1")),
  ("key1", KV("internalKey2", "value2")),
  ("key2", KV("internalKey3", "value3")),
  ("key2", KV("internalKey4", "value4")),
  ("key2", KV("internalKey5", "value5"))
).toDF("name", "merged")

def limitSize(n: Int, arrCol: Column): Column =
  array( (0 until n).map( arrCol.getItem ): _* )

df.
  groupBy("name").agg( collect_list(col("merged")).as("final") ).
  select( $"name", limitSize(2, $"final").as("final2") ).
  show(false)
// +----+----------------------------------------------+
// |name|final2                                        |
// +----+----------------------------------------------+
// |key1|[[internalKey1,value1], [internalKey2,value2]]|
// |key2|[[internalKey3,value3], [internalKey4,value4]]|
// +----+----------------------------------------------+
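
As a side note, limitSize pads with nulls when a group has fewer than n elements, since getItem on a missing index yields null. If you are on Spark 2.4 or later, the built-in slice function is a simpler alternative that just returns shorter arrays for smaller groups; a minimal sketch against the same df:

import org.apache.spark.sql.functions.{col, collect_list, slice}

// slice(column, start, length) uses a 1-based start index; groups with fewer
// than 2 elements simply yield shorter arrays instead of null padding.
df.groupBy("name")
  .agg(collect_list(col("merged")).as("final"))
  .select(col("name"), slice(col("final"), 1, 2).as("final2"))
  .show(false)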

Answer 1: (score 0)

So while a UDF does what you need, if you are looking for something more performant that is also memory-sensitive, the way to do it is to write a UDAF. Unfortunately, the UDAF API is not nearly as extensible as the aggregate functions that ship with Spark. However, you can use Spark's internal APIs to build on the internal functions and do what you need.

Here is an implementation of collect_list_limit that is mostly a copy of Spark's internal CollectList AggregateFunction. I would simply have extended it, but it is a case class. Really all that is needed is to override the update and merge methods to respect the passed-in limit:

import scala.collection.mutable

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.aggregate.{Collect, ImperativeAggregate}

// Builds on Spark's internal Collect aggregate; the exact internal API varies between Spark versions.
case class CollectListLimit(
    child: Expression,
    limitExp: Expression,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0) extends Collect[mutable.ArrayBuffer[Any]] {

  val limit = limitExp.eval( null ).asInstanceOf[Int]

  def this(child: Expression, limit: Expression) = this(child, limit, 0, 0)

  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty

  override def update(buffer: mutable.ArrayBuffer[Any], input: InternalRow): mutable.ArrayBuffer[Any] = {
    if( buffer.size < limit ) super.update(buffer, input)
    else buffer
  }

  override def merge(buffer: mutable.ArrayBuffer[Any], other: mutable.ArrayBuffer[Any]): mutable.ArrayBuffer[Any] = {
    if( buffer.size >= limit ) buffer
    else if( other.size >= limit ) other
    else ( buffer ++= other ).take( limit )
  }

  override def prettyName: String = "collect_list_limit"
}

To actually register it, we can go through Spark's internal FunctionRegistry, which takes the name and a builder, which is effectively a function that creates a CollectListLimit from the provided expressions:

import org.apache.spark.sql.catalyst.analysis.FunctionRegistry

val collectListBuilder = (args: Seq[Expression]) => CollectListLimit( args( 0 ), args( 1 ) )
FunctionRegistry.builtin.registerFunction( "collect_list_limit", collectListBuilder )

Edit:

It turns out that adding it to the builtins only works if you have not yet created the SparkContext, because an immutable clone is made at startup. If you already have a context, this should work to add it via reflection:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.catalog.SessionCatalog

// functionRegistry is a private field, so it has to be looked up via getDeclaredFields
val field = classOf[SessionCatalog].getDeclaredFields.find( _.getName.endsWith( "functionRegistry" ) ).get
field.setAccessible( true )
val inUseRegistry = field.get( SparkSession.builder.getOrCreate.sessionState.catalog ).asInstanceOf[FunctionRegistry]
inUseRegistry.registerFunction( "collect_list_limit", collectListBuilder )
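
Once registered under that name, the function should be callable through SQL or expr. A hypothetical usage sketch against the question's DataFrame (columns name and merged, limit of 2):

import org.apache.spark.sql.functions.expr

// Assumes collect_list_limit was registered as shown above; caps each
// group's collected list at 2 elements during the aggregation itself.
val limited = df
  .groupBy("name")
  .agg(expr("collect_list_limit(merged, 2)").as("final"))
limited.show(false)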

Answer 2: (score 0)

You can use a UDF.

Here is a possible example that does not require specifying a schema by hand and performs a meaningful reduce:

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.functions._

import scala.collection.mutable


object TestJob1 {

  def main(args: Array[String]): Unit = {

    val sparkSession = SparkSession
      .builder()
      .appName(this.getClass.getName.replace("$", ""))
      .master("local")
      .getOrCreate()

    val sc = sparkSession.sparkContext

    import sparkSession.sqlContext.implicits._

    val rawDf = Seq(
      ("key", 1L, "gargamel"),
      ("key", 4L, "pe_gadol"),
      ("key", 2L, "zaam"),
      ("key1", 5L, "naval")
    ).toDF("group", "quality", "other")

    rawDf.show(false)
    rawDf.printSchema

    val rawSchema = rawDf.schema

    val fUdf = udf(reduceByQuality, rawSchema)

    val aggDf = rawDf
      .groupBy("group")
      .agg(
        count(struct("*")).as("num_reads"),
        max(col("quality")).as("quality"),
        collect_list(struct("*")).as("horizontal")
      )
      .withColumn("short", fUdf($"horizontal"))
      .drop("horizontal")

    aggDf.printSchema

    aggDf.show(false)
  }

  // Reduce each collected list to the single row with the highest "quality"
  def reduceByQuality = (x: Any) => {

    val d = x.asInstanceOf[mutable.WrappedArray[GenericRowWithSchema]]

    d.reduce((r1, r2) => {

      val quality1 = r1.getAs[Long]("quality")
      val quality2 = r2.getAs[Long]("quality")

      if (quality1 >= quality2) r1 else r2
    })
  }
}

And here is an example with data similar to yours:

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types._
import org.apache.spark.sql.expressions._
import org.apache.spark.sql.functions._

import scala.collection.mutable


object TestJob {

  def main(args: Array[String]): Unit = {

    val sparkSession = SparkSession
      .builder()
      .appName(this.getClass.getName.replace("$", ""))
      .master("local")
      .getOrCreate()

    val sc = sparkSession.sparkContext

    import sparkSession.sqlContext.implicits._

    val df1 = Seq(
      ("key1", ("internalKey1", "value1")),
      ("key1", ("internalKey2", "value2")),
      ("key2", ("internalKey3", "value3")),
      ("key2", ("internalKey4", "value4")),
      ("key2", ("internalKey5", "value5"))
    ).toDF("name", "merged")

    //    df1.printSchema
    //    df1.show(false)

    val res = df1
      .groupBy("name")
      .agg(collect_list(col("merged")).as("final"))

    res.printSchema

    res.show(false)

    // Take the head of each collected list and render it as a string
    def f = (x: Any) => {
      val d = x.asInstanceOf[mutable.WrappedArray[GenericRowWithSchema]]
      d.head.toString
    }

    val fUdf = udf(f, StringType)

    val d2 = res
      .withColumn("d", fUdf(col("final")))
      .drop("final")

    d2.printSchema()

    d2.show(false)
  }
}
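
To actually cap the list at n elements rather than keeping only the head, the same untyped-UDF approach can return a truncated array; declaring the return type as the ArrayType of the final column lets Spark re-encode the rows. A rough, untested sketch, assuming the res DataFrame from the example above (n, takeN and takeNUdf are names introduced here):

import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.functions.{col, udf}

import scala.collection.mutable

// Keep only the first n rows of each collected list.
val n = 2
val takeN = (x: Any) =>
  x.asInstanceOf[mutable.WrappedArray[GenericRowWithSchema]].take(n)

// Reuse the array-of-struct type of the "final" column as the UDF's return type.
val takeNUdf = udf(takeN, res.schema("final").dataType)

val limited = res
  .withColumn("final2", takeNUdf(col("final")))
  .drop("final")

limited.show(false)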