Calculating cumulative counts up to time period "t" without hardcoding

Asked: 2016-09-01 19:51:21

Tags: scala dataframe

I want to compute cumulative counts across different time steps. I have a count of the events that occurred during each time period t; now I want the cumulative number of events up to and including that period.

I can easily compute each cumulative count separately, but that is tedious. I could also append them together with unionAll, but that is just as cumbersome when there are many time periods (a sketch of the kind of boilerplate I mean follows the expected output below).

How can I do this more cleanly?

package main.scala

import java.io.File
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.SparkConf
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.functions._

object Test {

    def main(args: Array[String]) {

        // Spark and SQL contexts; SQLContextSingleton is a small helper (not shown) that returns a single shared SQLContext
        val conf = new SparkConf().setAppName("Merger")
        val sc = new SparkContext(conf)
        val sqlContext = SQLContextSingleton.getInstance(sc)
        import sqlContext.implicits._

        // Count of events per id and time period t
        val count = Seq(("A",1,1),("A",1,2),("A",0,3),("A",0,4),("A",0,5),("A",1,6),
                        ("B",1,1),("B",0,2),("B",0,3),("B",1,4),("B",0,5),("B",1,6))
            .toDF("id","count","t")

        // Hardcoded cumulative counts up to t = 2 and t = 3 -- this per-period boilerplate is what I want to avoid
        val count2 = count.filter('t <= 2).groupBy('id).agg(sum("count"), max("t"))

        val count3 = count.filter('t <= 3).groupBy('id).agg(sum("count"), max("t"))

        count.show()
        count2.show()
        count3.show()
    }
}

count

+---+-----+---+
| id|count|  t|
+---+-----+---+
|  A|    1|  1|
|  A|    1|  2|
|  A|    0|  3|
|  A|    0|  4|
|  A|    0|  5|
|  A|    1|  6|
|  B|    1|  1|
|  B|    0|  2|
|  B|    0|  3|
|  B|    1|  4|
|  B|    0|  5|
|  B|    1|  6|
+---+-----+---+

count2

+---+----------+------+
| id|sum(count)|max(t)|
+---+----------+------+
|  A|         2|     2|
|  B|         1|     2|
+---+----------+------+

count3

+---+----------+------+
| id|sum(count)|max(t)|
+---+----------+------+
|  A|         2|     3|
|  B|         1|     3|
+---+----------+------+
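
For reference, the kind of hardcoded boilerplate I want to avoid would look roughly like this (a sketch continuing the code above, with count4 through count6 built exactly like count2 and count3):

// Tedious version: one hardcoded aggregate per time period, glued together with unionAll
val count4 = count.filter('t <= 4).groupBy('id).agg(sum("count"), max("t"))
val count5 = count.filter('t <= 5).groupBy('id).agg(sum("count"), max("t"))
val count6 = count.filter('t <= 6).groupBy('id).agg(sum("count"), max("t"))

val allCumulative = count2.unionAll(count3).unionAll(count4).unionAll(count5).unionAll(count6)
allCumulative.show()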

2 Answers:

Answer 0 (score: 0):

I suggest de-normalizing the data in a way that lets the counts be accumulated. This code should also scale well (there is only a single collect to the driver).

Sorry for not using the DataFrame API in my example (my Spark installation is a bit broken at the moment, so I could not test DataFrames):

val count = sc.makeRDD(Seq(("A",1,1),("A",1,2),("A",0,3),("A",0,4),("A",0,5),("A",1,6),
  ("B",1,1),("B",0,2),("B",0,3),("B",1,4),("B",0,5),("B",1,6)))

// This is required only if the number of timesteps is not known in advance.
// It is the only operation that collects data to the driver, and the result could even be broadcast if large.
val distinctTimesteps = count.map(_._3).distinct().sortBy(e => e, true).collect()

// this actually de-normalizes data so that it can be cumulated
val deNormalizedData = count.flatMap { case (id, c, t) =>
    // the trick is making composite key consisting of distinct timestep and your id: (distTimestep, id)
    distinctTimesteps.filter(distTimestep => distTimestep >= t).map(distTimestep => (distTimestep, id) -> c)
}

// just reduce by composite key and you are done
val cumulativeCounts = deNormalizedData.reduceByKey(_ + _)

// test
cumulativeCounts.collect().foreach(print)
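
If you do want to stay in the DataFrame API, the same de-normalize-and-aggregate idea can be written with a non-equi join instead of the flatMap. This is an untested sketch; countDF stands in for the count DataFrame from the question, and it assumes import sqlContext.implicits._ is in scope:

// Untested sketch: duplicate each row for every timestep at or after its own t, then aggregate
val countDF = Seq(("A",1,1),("A",1,2),("A",0,3),("A",0,4),("A",0,5),("A",1,6),
                  ("B",1,1),("B",0,2),("B",0,3),("B",1,4),("B",0,5),("B",1,6))
    .toDF("id","count","t")

val timesteps = countDF.select('t as "cumulativeT").distinct()

val cumulativeCountsDF = countDF
    .join(timesteps, $"cumulativeT" >= $"t")   // the non-equi join plays the role of the flatMap above
    .groupBy('id, $"cumulativeT")
    .agg(sum("count").as("cumulativeCount"))
    .orderBy('id, $"cumulativeT")

cumulativeCountsDF.show()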

Answer 1 (score: 0):

I have tested this with Spark 1.5.2 / Scala 2.10 and with Spark 2.0.0 / Scala 2.11, and it works like a charm. It did not work with Spark 1.6.2, I suspect because that build was not compiled with Hive support.

package main.scala

import java.io.File
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.SparkConf
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.SQLContext


object Test {

    def main(args: Array[String]) {

        val conf = new SparkConf().setAppName("Test")
        val sc = new SparkContext(conf)
        val sqlContext = SQLContextSingleton.getInstance(sc)
        import sqlContext.implicits._

        val data = Seq(("A",1,1,1),("A",3,1,3),("A",0,0,2),("A",4,0,4),("A",0,0,6),("A",2,1,5),
                         ("B",0,1,3),("B",0,0,4),("B",2,0,1),("B",2,1,2),("B",0,0,6),("B",1,1,5))
            .toDF("id","param1","param2","t")
        data.show()

        // Cumulative sums per id, ordered by t, computed with window functions
        data.withColumn("cumulativeSum1", sum("param1").over( Window.partitionBy("id").orderBy("t")))
            .withColumn("cumulativeSum2", sum("param2").over( Window.partitionBy("id").orderBy("t")))
            .show()
    }
}

An improvement I am still working on is applying this to several columns at once instead of repeating withColumn; one possible approach is sketched below. Input welcome!
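
For example, the repetition can be avoided by folding over the list of parameter columns (a sketch only, assuming the data DataFrame and the Window import from the code above):

// Sketch: apply the same cumulative-sum window to several columns via foldLeft
val window = Window.partitionBy("id").orderBy("t")
val paramCols = Seq("param1", "param2")

val cumulated = paramCols.foldLeft(data) { (df, c) =>
    df.withColumn(s"cumulative_$c", sum(c).over(window))
}
cumulated.show()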