Spark Dataframe GroupBy & complex case statement derivation

Date: 2017-11-17 15:36:40

Tags: scala apache-spark pyspark apache-spark-sql spark-dataframe

I need to implement the following scenarios on a DataFrame using Spark Scala (a plain-Scala sketch of the rule follows the scenario list):

Scenario 1: If the "KEY" exists only once, take its "TYPE_VAL" as is.
            Eg: KEY=66 exists once, so take TYPE_VAL=100
Scenario 2: If the "KEY" exists more than once and every occurrence has the same TYPE_VAL, take that TYPE_VAL once.
            Eg: for KEY=68, TYPE_VAL=23
Scenario 3: If the "KEY" exists more than once, with one TYPE_VAL repeated and other distinct values, subtract the other values from the repeated one.
            Eg: for KEY=67, TYPE_VAL=10 exists twice, so subtract 2 & 4 from 10; finally TYPE_VAL=4
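
In plain Scala terms, the per-key rule the three scenarios describe is roughly the sketch below (an illustration only; the helper name reduceTypeVals is made up for this example):

    def reduceTypeVals(vals: Seq[Int]): Int = vals match {
      case Seq(only)                   => only    // scenario 1: a single row
      case vs if vs.distinct.size == 1 => vs.head // scenario 2: all values equal
      case vs =>                                  // scenario 3: repeated value minus the others
        val repeated = vs.groupBy(identity).maxBy(_._2.size)._1
        repeated - vs.distinct.filterNot(_ == repeated).sum
    }

    reduceTypeVals(Seq(100))          // 100
    reduceTypeVals(Seq(23, 23))       // 23
    reduceTypeVals(Seq(10, 10, 2, 4)) // 10 - 2 - 4 = 4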

I tried using group by on the same KEY, but I could not derive all of the scenarios from it.

   //Sample input values
    val values = List(
      List("66", "100"),
      List("67", "10"), List("67", "10"), List("67", "2"), List("67", "4"),
      List("68", "23"), List("68", "23")
    ).map(x => (x(0), x(1)))

    import spark.implicits._
    //create a DataFrame
    val df1 = values.toDF("KEY","TYPE_VAL")

    df1.show(false)
    +---+--------+
    |KEY|TYPE_VAL|
    +---+--------+
    |66 |100     |
    |67 |10      |
    |67 |10      |
    |67 |2       |
    |67 |4       |
    |68 |23      |
    |68 |23      |
    +---+--------+

Expected output:

df2.show(false)
    +---+--------+
    |KEY|TYPE_VAL|
    +---+--------+
    |66 |100     | -------> [single row, so 100]
    |67 |4       | -------> [four rows, two of which are the same and the rest different, so (10 - 2 - 4) = 4]
    |68 |23      | -------> [two rows with the same value, so 23]
    +---+--------+

2 Answers:

Answer 0 (score: 1)

If you can assume that the number of records per key is not too large (say, at most a few thousand), you can use collect_list after the grouping to collect all matching values into an array, and then use a UDF to compute the result from that array:

import org.apache.spark.sql.functions._
import spark.implicits._

// create the sample data:
val df1 = List(
  (66, 100),
  (67, 10),
  (67, 10),
  (67, 2),
  (67, 4),
  (68, 23),
  (68, 23)
).toDF("KEY", "TYPE_VAL")

// define a UDF that computes the result per scenario for a given Seq[Int]. 
// This is just one possible implementation, simpler ones probably exist...
val computeTypeVal = udf { (vals: Seq[Int]) =>
  // Group equal values together and sort the groups by descending size, so that
  // a repeated value (if any) ends up at the head of the flattened list:
  vals.groupBy(identity).values.toList.sortBy(-_.size).flatten match {
    case a :: Nil => a                 // scenario 1: a single value
    case a :: b :: tail if a == b =>   // scenarios 2 & 3: the head value is repeated;
      a - tail.filterNot(_ == a).sum   // subtract the others (extra copies of a are dropped)
    case _ => 0 // or whatever else should be done for other cases
  }
}

// group by key, use functions.collect_list to collect all values per key, and apply the UDF
df1.groupBy($"KEY")
  .agg(collect_list($"TYPE_VAL") as "VALS")
  .select($"KEY", computeTypeVal($"VALS") as "TYPE_VAL")
  .sort($"KEY")
  .show() 
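
For the sample data above, this should print something like:

    +---+--------+
    |KEY|TYPE_VAL|
    +---+--------+
    | 66|     100|
    | 67|       4|
    | 68|      23|
    +---+--------+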

Answer 1 (score: 0)

Enhancing the solution shared by user Tzach Zohar to handle an input column whose values may arrive in different forms, such as Int, Double, and null:

val df1 = List(
  (66, Some("100")),
  (67, Some("10.4")),
  (67, Some("10.4")),
  (67, Some("2")),
  (67, Some("4")),
  (68, Some("23")),
  (68, Some("23")),
  (99, None),
  (999,Some(""))
).toDF("KEY", "TYPE_VAL")

df1.show()
+---+--------+
|KEY|TYPE_VAL|
+---+--------+
| 66|     100|
| 67|    10.4|
| 67|    10.4|
| 67|       2|
| 67|       4|
| 68|      23|
| 68|      23|
| 99|    null|
|999|        |
+---+--------+

The enhanced UDF is as follows:

val computeTypeVal = udf { (vals: Seq[String]) =>
  vals.groupBy(identity).values.toList.sortBy(-_.size).flatten match {
    case a :: Nil => if (a == "") None else Some(a.toDouble) // an empty string becomes null
    case a :: b :: tail if a == b => Some(a.toDouble - tail.map(_.toDouble).filterNot(_ == a.toDouble).sum)
    case _ => Some(0.00) // or whatever else should be done for other cases
  }
}

df1.groupBy($"KEY").agg(collect_list($"TYPE_VAL") as "VALS").select($"KEY", computeTypeVal($"VALS") as "TYPE_VAL").show()

+---+--------+
|KEY|TYPE_VAL|
+---+--------+
| 68|    23.0|
|999|    null|
| 99|     0.0|
| 66|   100.0|
| 67|     4.4|
+---+--------+
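
Note that collect_list skips nulls, so for KEY=99 the UDF receives an empty Seq, falls through to the default case, and returns 0.0; KEY=999 carries an empty string, which the first case maps to None, and Spark renders that None as null.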