Spark将sql窗口函数迁移到RDD以获得更好的性能

时间:2017-01-03 14:19:49

标签: scala apache-spark apache-spark-sql rdd

应该为数据框中的多个列执行一个函数

def handleBias(df: DataFrame, colName: String, target: String = target) = {
    val w1 = Window.partitionBy(colName)
    val w2 = Window.partitionBy(colName, target)

    df.withColumn("cnt_group", count("*").over(w2))
      .withColumn("pre2_" + colName, mean(target).over(w1))
      .withColumn("pre_" + colName, coalesce(min(col("cnt_group") / col("cnt_foo_eq_1")).over(w1), lit(0D)))
      .drop("cnt_group")
  }

这可以很好地编写,如上面的spark-SQL和for循环中所示。然而,这导致了很多混乱(spark apply function to columns in parallel)。

一个最小的例子:

  val df = Seq(
    (0, "A", "B", "C", "D"),
    (1, "A", "B", "C", "D"),
    (0, "d", "a", "jkl", "d"),
    (0, "d", "g", "C", "D"),
    (1, "A", "d", "t", "k"),
    (1, "d", "c", "C", "D"),
    (1, "c", "B", "C", "D")
  ).toDF("TARGET", "col1", "col2", "col3TooMany", "col4")

  val columnsToDrop = Seq("col3TooMany")
  val columnsToCode = Seq("col1", "col2")
  val target = "TARGET"

  val targetCounts = df.filter(df(target) === 1).groupBy(target)
    .agg(count(target).as("cnt_foo_eq_1"))
  val newDF = df.join(broadcast(targetCounts), Seq(target), "left")

  val result = (columnsToDrop ++ columnsToCode).toSet.foldLeft(newDF) {
    (currentDF, colName) => handleBias(currentDF, colName)
  }

  result.drop(columnsToDrop: _*).show

如何使用RDD API更有效地表达这一点? aggregateByKey应该是一个好主意,但我仍然不清楚如何在此处应用它来替换窗口函数。

(提供更多上下文/更大的示例https://github.com/geoHeil/sparkContrastCoding

修改

最初,我从Spark dynamic DAG is a lot slower and different from hard coded DAG开始,如下所示。好消息是,每列似乎都是独立/并行的。缺点是连接(即使对于300 MB的小数据集)也会变得太大而且#34;并导致一个反应迟钝的火花。

handleBiasOriginal("col1", df)
    .join(handleBiasOriginal("col2", df), df.columns)
    .join(handleBiasOriginal("col3TooMany", df), df.columns)
    .drop(columnsToDrop: _*).show

  def handleBiasOriginal(col: String, df: DataFrame, target: String = target): DataFrame = {
    val pre1_1 = df
      .filter(df(target) === 1)
      .groupBy(col, target)
      .agg((count("*") / df.filter(df(target) === 1).count).alias("pre_" + col))
      .drop(target)

    val pre2_1 = df
      .groupBy(col)
      .agg(mean(target).alias("pre2_" + col))

    df
      .join(pre1_1, Seq(col), "left")
      .join(pre2_1, Seq(col), "left")
      .na.fill(0)
  }

此图像为spark 2.1.0,Spark dynamic DAG is a lot slower and different from hard coded DAG的图像为2.0.2 toocomplexDAG

应用缓存时,DAG会更简单一些     df.cache     handleBiasOriginal(" col1",df)。 ...

您认为窗口函数还有哪些其他可能性来优化SQL? 如果SQL是动态生成的,那么它最好是很好的。

caching

2 个答案:

答案 0 :(得分:2)

这里的要点是避免不必要的洗牌。现在,您的代码会为您要包含的每个列重复两次,并且无法在列之间重复使用生成的数据布局。

为简单起见,我假设target始终是二进制({0,1}),并且您使用的所有剩余列都是StringType。此外,我假设列的基数足够低,以便将结果分组并在本地处理。您可以调整这些方法来处理其他情况,但需要更多工作。

RDD API

  • 从长到长重塑数据:

    import org.apache.spark.sql.functions._
    
    val exploded = explode(array(
      (columnsToDrop ++ columnsToCode).map(c => 
        struct(lit(c).alias("k"), col(c).alias("v"))): _*
    )).alias("level")
    
    val long = df.select(exploded, $"TARGET")
    
  • aggregateByKey,重塑并收集:

    import org.apache.spark.util.StatCounter
    
    val lookup = long.as[((String, String), Int)].rdd
      // You can use prefix partitioner (one that depends only on _._1)
      // to avoid reshuffling for groupByKey
      .aggregateByKey(StatCounter())(_ merge _, _ merge _)
      .map { case ((c, v), s) => (c, (v, s)) }
      .groupByKey
      .mapValues(_.toMap)
      .collectAsMap
    
  • 您可以使用lookup获取各个列和级别的统计信息。例如:

    lookup("col1")("A")
    
    org.apache.spark.util.StatCounter = 
      (count: 3, mean: 0.666667, stdev: 0.471405, max: 1.000000, min: 0.000000)
    

    为您提供col1,级别A的数据。基于二进制TARGET假设,此信息已完成(您可以获得两个类的计数/分数)。

    您可以使用这样的查找来生成SQL表达式或将其传递给udf并将其应用于各个列。

DataFrame API

  • 将数据转换为long,与RDD API一样。
  • 根据级别计算聚合:

    val stats = long
      .groupBy($"level.k", $"level.v")
      .agg(mean($"TARGET"), sum($"TARGET"))
    
  • 根据您的偏好,您可以对其进行重新整形,以实现有效的连接或转换为本地集合,类似于RDD解决方案。

答案 1 :(得分:0)

使用aggregateByKey 可以找到关于aggregateByKey的简单说明here。基本上你使用两个函数:一个在分区内工作,另一个在分区之间工作。

您需要在第一列进行聚合之类的操作,并在内部构建数据结构,并在第二列的每个元素上使用映射来聚合和收集数据(当然,如果需要,您可以执行两个aggregateByKey)。 这不会解决在您想要使用的每个列的代码上执行多次运行的情况(您可以使用聚合而不是aggregateByKey来处理所有数据并将其放在映射中但这可能会让您更糟糕性能)。结果将是每个键一行,如果你想回到原始记录(如窗口函数那样),你实际上需要将这个值与原始RDD连接或在内部保存所有值并平面图

我不相信这会为您带来任何真正的性能提升。你会做很多工作来重新实现在SQL中为你完成的事情,同时这样你将失去SQL的大部分优势(催化剂优化,钨存储器管理,整个阶段代码生成等)。

改进SQL

我要做的是尝试改进SQL本身。 例如,窗口函数中列的结果对于所有值看起来都是相同的。你真的需要一个窗口功能吗?你可以改为使用groupBy而不是窗口函数(如果你真的需要每个记录,你可以尝试加入结果。这可能会提供更好的性能,因为它不一定意味着每一步都要洗两次)。