Spark: mapGroups on a Dataset

Date: 2018-03-15 04:05:03

Tags: apache-spark spark-dataframe apache-spark-dataset

I tried the mapGroups function on the dataset below and I'm not sure why I'm getting 0 in the "TotalValue" column. Am I missing something here? Please advise.

Spark version - 2.0, Scala version - 2.11

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions._

case class Record(Hour: Int, Category: String, TotalComm: Double, TotalValue: Int)
val ss = SparkSession.builder().master("local[*]").getOrCreate()
import ss.implicits._

val df: DataFrame = ss.sparkContext.parallelize(Seq(
(0, "cat26", 30.9, 200), (0, "cat26", 22.1, 100), (0, "cat95", 19.6, 300), (1, "cat4", 1.3, 100),
(1, "cat23", 28.5, 100), (1, "cat4", 26.8, 400), (1, "cat13", 12.6, 250), (1, "cat23", 5.3, 300),
(0, "cat26", 39.6, 30), (2, "cat40", 29.7, 500), (1, "cat4", 27.9, 600), (2, "cat68", 9.8, 100),
(1, "cat23", 35.6, 500))).toDF("Hour", "Category","TotalComm", "TotalValue")

val resultSum = df.as[Record].map(row => ((row.Hour,row.Category),(row.TotalComm,row.TotalValue)))
.groupByKey(_._1).mapGroups{case(k,iter) => (k._1,k._2,iter.map(x => x._2._1).sum,iter.map(y => y._2._2).sum)}
.toDF("KeyHour","KeyCategory","TotalComm","TotalValue").orderBy(asc("KeyHour"))

resultSum.show()

+-------+-----------+---------+----------+
|KeyHour|KeyCategory|TotalComm|TotalValue|
+-------+-----------+---------+----------+
|      0|      cat26|     92.6|         0|
|      0|      cat95|     19.6|         0|
|      1|      cat13|     12.6|         0|
|      1|      cat23|     69.4|         0|
|      1|       cat4|     56.0|         0|
|      2|      cat40|     29.7|         0|
|      2|      cat68|      9.8|         0|
+-------+-----------+---------+----------+  

2 Answers:

Answer 0 (score: 2):

The iter inside mapGroups is an iterator, so it can be traversed only once. After you compute iter.map(x => x._2._1).sum, there is nothing left in iter, and the subsequent iter.map(y => y._2._2).sum therefore yields 0. You have to find a mechanism that computes both sums in a single pass over the iterator.
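
You can see the same exhaustion behaviour with a plain Scala iterator (a minimal illustration of the point, not part of the original post):

val it = Iterator((30.9, 200), (22.1, 100), (39.6, 30))
val commSum  = it.map(_._1).sum   // 92.6; this first traversal exhausts the iterator
val valueSum = it.map(_._2).sum   // 0: the iterator is already empty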

For loop with ListBuffers

For simplicity, I used a for loop with two ListBuffers to accumulate both sums in the same pass:

import scala.collection.mutable.ListBuffer

val resultSum = df.as[Record].map(row => ((row.Hour,row.Category),(row.TotalComm,row.TotalValue)))
  .groupByKey(_._1).mapGroups{case(k,iter) => {
  val listBuffer1 = new ListBuffer[Double]
  val listBuffer2 = new ListBuffer[Int]
      for(a <- iter){
        listBuffer1 += a._2._1
        listBuffer2 += a._2._2
      }
      (k._1, k._2, listBuffer1.sum, listBuffer2.sum)
    }}
  .toDF("KeyHour","KeyCategory","TotalComm","TotalValue").orderBy($"KeyHour".asc)

This should give you the correct result:

+-------+-----------+---------+----------+
|KeyHour|KeyCategory|TotalComm|TotalValue|
+-------+-----------+---------+----------+
|      0|      cat26|     92.6|       330|
|      0|      cat95|     19.6|       300|
|      1|      cat23|     69.4|       900|
|      1|      cat13|     12.6|       250|
|      1|       cat4|     56.0|      1100|
|      2|      cat68|      9.8|       100|
|      2|      cat40|     29.7|       500|
+-------+-----------+---------+----------+
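
An equivalent single-pass variant that avoids the mutable buffers can be sketched with foldLeft (my sketch, not from the original answer, assuming the same df and Record as above):

val resultSum2 = df.as[Record].map(row => ((row.Hour, row.Category), (row.TotalComm, row.TotalValue)))
  .groupByKey(_._1)
  .mapGroups { case (k, iter) =>
    // a single traversal accumulates both sums at once
    val (comm, value) = iter.foldLeft((0.0, 0)) { case ((c, v), (_, (tc, tv))) => (c + tc, v + tv) }
    (k._1, k._2, comm, value)
  }
  .toDF("KeyHour", "KeyCategory", "TotalComm", "TotalValue").orderBy($"KeyHour".asc)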

I hope the answer is helpful.

Answer 1 (score: 1):

As Ramesh Maharjan pointed out, the problem lies in using the iterator twice, which results in the TotalValue column being 0. However, there is no need to use groupByKey and mapGroups in the first place. The same can be done with groupBy and agg, which makes the code cleaner and easier to read. As a plus, it also avoids the slow groupByKey.

The following will do the same:

val resultSum = df.groupBy($"Hour", $"Category")
  .agg(sum($"TotalComm").as("TotalComm"), sum($"TotalValue").as("TotalValue"))
  .orderBy(asc("Hour"))

Result:

+----+--------+---------+----------+
|Hour|Category|TotalComm|TotalValue|
+----+--------+---------+----------+
|   0|   cat95|     19.6|       300|
|   0|   cat26|     92.6|       330|
|   1|   cat23|     69.4|       900|
|   1|   cat13|     12.6|       250|
|   1|    cat4|     56.0|      1100|
|   2|   cat68|      9.8|       100|
|   2|   cat40|     29.7|       500|
+----+--------+---------+----------+

If you still want to change the names of the Hour and Category columns to KeyHour and KeyCategory, that can easily be done by aliasing the columns in the groupBy, as sketched below.
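
For instance, a minimal sketch of that aliasing, assuming the same df as above:

// aliasing the grouping columns renames them in the output
val renamed = df.groupBy($"Hour".as("KeyHour"), $"Category".as("KeyCategory"))
  .agg(sum($"TotalComm").as("TotalComm"), sum($"TotalValue").as("TotalValue"))
  .orderBy(asc("KeyHour"))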