How to filter a big dataframe many times (once per row of a small dataframe, i.e. as many times as the small df has rows)?

Time: 2018-05-27 07:52:06

Tags: scala apache-spark apache-spark-sql

I have two Spark dataframes, dfA and dfB. I want to filter dfA by each row of dfB, which means that if dfB has 10000 rows, I need to filter dfA 10000 times with 10000 different filter conditions generated from dfB. Then, after each filter, I need to collect the filter result as a column in dfB.



dfA                                    dfB
+------+---------+---------+           +-----+-------------+--------------+
|  id  |  value1 |  value2 |           | id  |  min_value1 |  max_value1  |
+------+---------+---------+           +-----+-------------+--------------+            
|  1   |    0    |   4345  |           |  1  |     0       |       3      |
|  1   |    1    |   3434  |           |  1  |     5       |       9      |
|  1   |    2    |   4676  |           |  2  |     1       |       4      |
|  1   |    3    |   3454  |           |  2  |     6       |       8      |
|  1   |    4    |   9765  |           +-----+-------------+--------------+
|  1   |    5    |   5778  |           ....more rows, nearly 10000 rows.
|  1   |    6    |   5674  |
|  1   |    7    |   3456  |
|  1   |    8    |   6590  |
|  1   |    9    |   5461  |
|  1   |    10   |   4656  |
|  2   |    0    |   2324  |
|  2   |    1    |   2343  |
|  2   |    2    |   4946  |
|  2   |    3    |   4353  |
|  2   |    4    |   4354  |
|  2   |    5    |   3234  |
|  2   |    6    |   8695  |
|  2   |    7    |   6587  |
|  2   |    8    |   5688  |
+------+---------+---------+
......more rows, nearly one billion rows

So my expected result is

resultDF
+-----+-------------+--------------+----------------------------+
| id  |  min_value1 |  max_value1  |          results           |
+-----+-------------+--------------+----------------------------+            
|  1  |     0       |       3      | [4345,3434,4676,3454]      |
|  1  |     5       |       9      | [5778,5674,3456,6590,5461] |
|  2  |     1       |       4      | [2343,4946,4353,4354]      |
|  2  |     6       |       8      | [8695,6587,5688]           |
+-----+-------------+--------------+----------------------------+
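
For anyone who wants to try the snippets below locally, here is a minimal sketch (my addition, not part of the original question) that builds toy dataframes of the same shape from the sample rows shown above, assuming an existing SparkSession named spark rather than the real billion-row parquet data:

// Toy dfA and dfB built from the sample rows above (assumption: SparkSession `spark` exists)
import spark.implicits._

val dfA = Seq(
  (1, 0, 4345), (1, 1, 3434), (1, 2, 4676), (1, 3, 3454), (1, 4, 9765),
  (1, 5, 5778), (1, 6, 5674), (1, 7, 3456), (1, 8, 6590), (1, 9, 5461), (1, 10, 4656),
  (2, 0, 2324), (2, 1, 2343), (2, 2, 4946), (2, 3, 4353), (2, 4, 4354),
  (2, 5, 3234), (2, 6, 8695), (2, 7, 6587), (2, 8, 5688)
).toDF("id", "value1", "value2")

val dfB = Seq(
  (1, 0, 3), (1, 5, 9), (2, 1, 4), (2, 6, 8)
).toDF("id", "min_value1", "max_value1")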

My stupid solution is the brute-force code below (maybe you don't want to look at it): for each id in dfB, filter dfA once per (min_value1, max_value1) range and union the aggregated results.

import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions._

def tempFunction(id: Int, dfA: DataFrame, dfB: DataFrame): DataFrame = {
    val dfa = dfA.filter("id = " + id)
    val dfb = dfB.filter("id = " + id)
    // collect all (min_value1, max_value1) ranges of this id to the driver
    val arr = dfb.groupBy("id")
                 .agg(collect_list(struct("min_value1", "max_value1")))
                 .collect()

    val rangArray = arr(0)(1).asInstanceOf[Seq[Row]]   // get range array of this id
    // initialize resultDF with the results of the first range query
    val min_value1 = rangArray(0).getAs[Int](0)
    val max_value1 = rangArray(0).getAs[Int](1)
    val s = "value1 between " + min_value1 + " and " + max_value1
    var resultDF = dfa.filter(s).groupBy("id")
                      .agg(collect_list("value2").as("results"),   // collect value2 of the matching rows
                           min("value1").as("min_value1"),
                           max("value1").as("max_value1"))
    // run the remaining range queries and union each result
    for (i <- 1 to rangArray.length - 1) {
       val temp_min_value1 = rangArray(i).getAs[Int](0)
       val temp_max_value1 = rangArray(i).getAs[Int](1)
       val query = "value1 between " + temp_min_value1 + " and " + temp_max_value1
       val tempResultDF = dfa.filter(query).groupBy("id")
                             .agg(collect_list("value2").as("results"),
                                  min("value1").as("min_value1"),
                                  max("value1").as("max_value1"))
       resultDF = resultDF.union(tempResultDF)
    }

    resultDF
}

def myFunction(): DataFrame = {
  val dfA = spark.read.parquet(routeA)
  val dfB = spark.read.parquet(routeB)

  // collect all distinct ids to the driver
  val idArrays = dfB.select("id").distinct().collect()
  // initialize the result with the first id
  var resultDF = tempFunction(idArrays(0).getAs[Int](0), dfA, dfB)
  // traverse the remaining ids and union their results
  for (i <- 1 to idArrays.length - 1) {
     val tempDF = tempFunction(idArrays(i).getAs[Int](0), dfA, dfB)
     resultDF = resultDF.union(tempDF)
  }
  resultDF
}

I have tried my algorithm; it takes almost 50 hours.

Does anyone have a more efficient approach? Thanks a lot.

2 Answers:

Answer 0 (score: 1)

Assuming that your dfB is a small dataset, I am trying to give the solution below.

Try using a broadcast join like the one below:

import org.apache.spark.sql.functions.{broadcast, col, collect_list}

dfA.as("dfA").join(broadcast(dfB.as("dfB")),      // alias both sides so the qualified columns resolve
    col("dfA.id") === col("dfB.id") && col("dfA.value1") >= col("dfB.min_value1") && col("dfA.value1") <= col("dfB.max_value1"))
  .groupBy(col("dfA.id"), col("dfB.min_value1"), col("dfB.max_value1"))  // keep the range columns so each dfB row gets its own list
  .agg(collect_list(col("dfA.value2")).as("results"))

A broadcast join works like a map-side join: the smaller dataset is materialized for all the mappers. This improves performance by omitting the sort and shuffle phases that would otherwise be required during the reduce step.
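
As a side note (my addition, not part of the original answer), Spark also broadcasts the smaller side automatically when its estimated size is below spark.sql.autoBroadcastJoinThreshold; the explicit broadcast() call above simply forces that hint. A small sketch, where the threshold value is illustrative and joinedDF is a hypothetical name for the dataframe produced by the snippet above:

// Sketch only: 10 MB is an arbitrary example value, `joinedDF` is a placeholder name
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", 10L * 1024 * 1024)

// Verify in the physical plan that a broadcast join was actually chosen
joinedDF.explain()   // look for a Broadcast*Join node instead of SortMergeJoin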

A few things I would like you to avoid:

Never use collect(). When a collect operation is issued on an RDD, the whole dataset is copied to the driver.

If your data is too large, you may get an out-of-memory exception.

Try to use take() or takeSample() instead.
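
A minimal sketch of those alternatives (my addition; the row counts are arbitrary):

// Bring only a bounded number of rows to the driver instead of the whole dataset
val firstRows  = dfA.take(20)                     // Array[Row], at most 20 rows
val sampleRows = dfA.rdd.takeSample(false, 20)    // random sample of 20 rows, without replacement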

Answer 1 (score: 0)

Obviously, when two dataframes/datasets are involved in a calculation, a join has to be performed. So a join is a must for you. But when to join is the important question.

I would suggest aggregating and reducing the rows in the dataframes as much as possible before joining, as it reduces shuffling.

In your case you can reduce only dfA, because you need dfB exactly as it is, with one column added from the dfA rows that satisfy the condition.

So you can groupBy id and aggregate dfA so that you get one row per id, and then perform the join. After that, you can use a udf function for the calculation logic.

Comments are provided for clarity and explanation.

import org.apache.spark.sql.Row
import org.apache.spark.sql.functions._

//udf function to keep only the collected value2 whose value1 falls within the range [min_value1, max_value1]
def selectRangedValue2Udf = udf((minValue: Int, maxValue: Int, list: Seq[Row]) =>
  list.filter(row => row.getAs[Int]("value1") <= maxValue && row.getAs[Int]("value1") >= minValue)
      .map(_.getAs[Int]("value2")))


dfA.groupBy("id")                                                   //grouping by id
  .agg(collect_list(struct("value1", "value2")).as("collection"))   //collecting all the value1 and value2 as structs
  .join(dfB, Seq("id"), "right")                                    //joining both dataframes on id
  .select(col("id"), col("min_value1"), col("max_value1"),
          selectRangedValue2Udf(col("min_value1"), col("max_value1"), col("collection")).as("results"))  //calling the udf function defined above

which should give you

+---+----------+----------+------------------------------+
|id |min_value1|max_value1|results                       |
+---+----------+----------+------------------------------+
|1  |0         |3         |[4345, 3434, 4676, 3454]      |
|1  |5         |9         |[5778, 5674, 3456, 6590, 5461]|
|2  |1         |4         |[2343, 4946, 4353, 4354]      |
|2  |6         |8         |[8695, 6587, 5688]            |
+---+----------+----------+------------------------------+

I hope the answer is helpful.