Categorize data after counting occurrences in Spark

Date: 2017-06-23 16:09:52

Tags: apache-spark apache-spark-sql

The best way to describe the problem is to give an example of the input and the output I want.

Input

+--+----------+-----+
|id|timestamp |count|
+--+----------+-----+
| 1|2017-06-22|  1  |
| 1|2017-06-23|  0  |
| 1|2017-06-24|  1  |
| 2|2017-06-22|  0  |
| 2|2017-06-23|  1  |
+--+----------+-----+

The logic is something like:

if (the total number of 1s in count over the last Y days is equal to or higher than X)
    code = True
else
    code = False

Let's say X = 5 and Y = 2; then the output should be:

Output

+----+-------+
| id | code  |
+----+-------+
|  1 | True  |
|  2 | False |
+----+-------+

The input is a Spark SQL DataFrame (org.apache.spark.sql.DataFrame).

It doesn't sound like a very complicated problem, but I'm still stuck on the first step. I've only managed to load the data into a DataFrame!

Any ideas?

1 Answer:

Answer 0 (score: 1)

Given your requirement, a UDAF (user-defined aggregate function) fits best. You can check out the databricks and ragrawal write-ups to get a better understanding.

Here is some guidance based on my own understanding; I hope it helps.

First, you need to define the UDAF. Once you have read the links above, you should be able to do so.

import java.time.LocalDate
import java.time.format.DateTimeFormatter

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

private class ManosAggregateFunction(daysToCheck: Int, countsToCheck: Int) extends UserDefinedAggregateFunction {

  private val formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd")

  // each input row carries the day's timestamp (as a yyyy-MM-dd string) and its count
  def inputSchema: StructType = new StructType().add("timestamp", StringType).add("count", IntegerType)
  // the buffer keeps the last timestamp seen, the running total of counts
  // inside the window, and the number of days covered so far
  def bufferSchema: StructType = new StructType().add("timestamp", StringType).add("count", IntegerType).add("days", IntegerType)
  // the final result is a single Boolean: the "code" column
  def dataType: DataType = BooleanType
  // the same input always yields the same result
  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, "") // last timestamp seen
    buffer.update(1, 0)  // counts accumulated inside the window
    buffer.update(2, 0)  // days covered so far
  }

  // rows are expected in descending timestamp order (see the orderBy below);
  // each row extends the window further back from the most recent date
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val nowDate = input.getString(0)
    val count = input.getInt(1)
    val previousDate = buffer.getString(0)
    if (previousDate != "") {
      val oldDate = LocalDate.parse(previousDate, formatter)
      val newDate = LocalDate.parse(nowDate, formatter)
      buffer.update(2, buffer.getInt(2) + (oldDate.toEpochDay() - newDate.toEpochDay()).toInt)
    }
    buffer.update(0, nowDate)
    if (buffer.getInt(2) < daysToCheck) {
      buffer.update(1, buffer.getInt(1) + count)
    }
  }

  // combine two partial buffers; this is only exact when the buffers also
  // arrive in descending date order
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    if (buffer2.getString(0) == "") {
      // nothing to merge
    } else if (buffer1.getString(0) == "") {
      // buffer1 is still empty: adopt buffer2 as-is
      buffer1.update(0, buffer2.getString(0))
      buffer1.update(1, buffer2.getInt(1))
      buffer1.update(2, buffer2.getInt(2))
    } else {
      // the difference between the two buffers' earliest dates already
      // covers the days spanned inside buffer2
      val oldDate = LocalDate.parse(buffer1.getString(0), formatter)
      val newDate = LocalDate.parse(buffer2.getString(0), formatter)
      buffer1.update(2, buffer1.getInt(2) + (oldDate.toEpochDay() - newDate.toEpochDay()).toInt)
      buffer1.update(0, buffer2.getString(0))
      if (buffer1.getInt(2) < daysToCheck) {
        buffer1.update(1, buffer1.getInt(1) + buffer2.getInt(1))
      }
    }
  }

  def evaluate(buffer: Row): Any = countsToCheck <= buffer.getInt(1)
}

In the UDAF above, daysToCheck and countsToCheck are the X and Y from your question (here a window of 5 days and a threshold of 2 ones).

You can invoke the defined UDAF as follows:

import org.apache.spark.sql.functions.col
import spark.implicits._ // assuming a SparkSession named spark, as in spark-shell

val manosAgg = new ManosAggregateFunction(5, 2)
df.orderBy($"timestamp".desc)
  .groupBy("id")
  .agg(manosAgg(col("timestamp"), col("count")).as("code"))
  .show
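One caveat: Spark does not guarantee that the orderBy above survives the shuffle introduced by groupBy, so the row order seen by update and merge is not strictly guaranteed. It works for this small example, but keep it in mind on real data.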

Final output:

+---+-----+
| id| code|
+---+-----+
|  1| true|
|  2|false|
+---+-----+

given the input:

val df = Seq(
  (1, "2017-06-22", 1),
  (1, "2017-06-23", 0),
  (1, "2017-06-24", 1),
  (2, "2017-06-28", 0),
  (2, "2017-06-29", 1)
).toDF("id","timestamp","count")
+---+----------+-----+
|id |timestamp |count|
+---+----------+-----+
|1  |2017-06-22|1    |
|1  |2017-06-23|0    |
|1  |2017-06-24|1    |
|2  |2017-06-28|0    |
|2  |2017-06-29|1    |
+---+----------+-----+
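As an aside, the same rule can also be sketched with built-in aggregates instead of a UDAF. This is only a minimal sketch, assuming the df above, timestamps kept as yyyy-MM-dd strings, and the window anchored at each id's most recent date; daysToCheck and countsToCheck are the same parameters as in the UDAF:

import org.apache.spark.sql.functions._

val daysToCheck = 5   // size of the date window
val countsToCheck = 2 // threshold on the number of 1s

// anchor the window at the latest date per id
val latest = df.groupBy("id").agg(max("timestamp").as("latest"))

df.join(latest, "id")
  .where(datediff(col("latest"), col("timestamp")) < daysToCheck)
  .groupBy("id")
  .agg((sum("count") >= countsToCheck).as("code"))
  .show

Whether this matches the UDAF exactly depends on how you want date gaps at the window boundary handled, so treat it as a starting point rather than a drop-in replacement.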

I hope this solves your problem. :)