The best way to describe the problem is with an example of the input and the output I want.
Input

+---+----------+-----+
| id|timestamp |count|
+---+----------+-----+
|  1|2017-06-22|    1|
|  1|2017-06-23|    0|
|  1|2017-06-24|    1|
|  2|2017-06-22|    0|
|  2|2017-06-23|    1|
+---+----------+-----+
The logic is something like:

if (the total number of 1s in count over the past Y days is equal to or higher than X)
  code = True
else
  code = False

Let's say X = 5 and Y = 2.
Then the output should be:

Output

+---+-------+
| id| code  |
+---+-------+
|  1| True  |
|  2| False |
+---+-------+
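To make the rule concrete before bringing Spark into it, here is a plain-Scala sketch over an in-memory sequence. The function name `computeCode` and the window reading "strictly less than Y days back from each id's most recent timestamp" are my own illustrative choices, and the thresholds used here (at least 2 ones within a 5-day window) are the ones the answer below ends up instantiating:

```scala
import java.time.LocalDate

// Illustrative plain-Scala sketch (not Spark) of the rule: for each id,
// is the sum of `count` within the last `y` days, measured back from that
// id's most recent timestamp, at least `x`? Names are hypothetical.
def computeCode(rows: Seq[(Int, String, Int)], x: Int, y: Int): Map[Int, Boolean] =
  rows.groupBy(_._1).map { case (id, rs) =>
    val parsed = rs.map { case (_, ts, c) => (LocalDate.parse(ts).toEpochDay, c) }
    val latest = parsed.map(_._1).max
    // keep only rows whose date falls inside the y-day window ending at `latest`
    val total = parsed.collect { case (d, c) if latest - d < y => c }.sum
    id -> (total >= x)
  }

val sample = Seq(
  (1, "2017-06-22", 1), (1, "2017-06-23", 0), (1, "2017-06-24", 1),
  (2, "2017-06-22", 0), (2, "2017-06-23", 1)
)
// With a threshold of 2 ones within a 5-day window:
println(computeCode(sample, 2, 5) == Map(1 -> true, 2 -> false))  // true
```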
The input is a Spark SQL dataframe (org.apache.spark.sql.DataFrame).

It doesn't sound like a very complicated problem, but I'm still stuck on the first step. I've only managed to load the data into a dataframe!

Any ideas?
Answer 0 (score: 1)
For your requirement, a UDAF (user-defined aggregate function) fits best. You can check out the databricks and ragrawal write-ups for a better understanding.

I'm giving you guidance based on my own understanding; I hope it helps.

First, you need to define the UDAF. After going through the links above, you should be able to do that.
import java.time.LocalDate
import java.time.format.DateTimeFormatter
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

private class ManosAggregateFunction(daysToCheck: Int, countsToCheck: Int) extends UserDefinedAggregateFunction {

  // input columns fed to the aggregate: timestamp and count
  def inputSchema: StructType = new StructType().add("timestamp", StringType).add("count", IntegerType)

  // aggregation buffer: last timestamp seen, running count, days covered so far
  def bufferSchema: StructType = new StructType().add("timestamp", StringType).add("count", IntegerType).add("days", IntegerType)

  // the aggregate returns a Boolean: the value of `code`
  def dataType: DataType = BooleanType

  // the same input always gives the same result
  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, "")
    buffer.update(1, 0)
    buffer.update(2, 0)
  }

  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer.update(0, input.getString(0)) // timestamp of the current row
    buffer.update(1, input.getInt(1))    // count of the current row
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd")
    val previousDate = buffer1.getString(0)
    val nowDate = buffer2.getString(0)
    if (previousDate != "") {
      // accumulate the gap in days between consecutive rows
      val oldDate = LocalDate.parse(previousDate, formatter)
      val newDate = LocalDate.parse(nowDate, formatter)
      buffer1.update(2, buffer1.getInt(2) + (oldDate.toEpochDay() - newDate.toEpochDay()).toInt)
    }
    buffer1.update(0, buffer2.getString(0))
    // only counts that still fall inside the daysToCheck window contribute
    if (buffer1.getInt(2) < daysToCheck) {
      buffer1.update(1, buffer1.getInt(1) + buffer2.getInt(1))
    }
  }

  def evaluate(buffer: Row): Any = {
    countsToCheck <= buffer.getInt(1)
  }
}
In the UDAF above, daysToCheck and countsToCheck correspond to Y and X in your question.
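One way to sanity-check the buffer logic is to trace it in plain Scala. The fold below is my own simplified, single-threaded mirror of the update/merge steps above, assuming rows arrive sorted by timestamp descending (matching the orderBy in the call below); it is illustrative only and not how Spark actually schedules merges:

```scala
import java.time.LocalDate

// Simplified single-threaded mirror of the UDAF's buffer logic (illustrative).
// The buffer is (last date seen, running count, days covered so far).
def foldBuffer(rows: Seq[(String, Int)], daysToCheck: Int, countsToCheck: Int): Boolean = {
  val (_, total, _) = rows.foldLeft(("", 0, 0)) {
    case ((prevDate, sum, days), (date, count)) =>
      // gap in days between the previous row and this one (rows sorted descending)
      val gap =
        if (prevDate == "") 0
        else (LocalDate.parse(prevDate).toEpochDay - LocalDate.parse(date).toEpochDay).toInt
      val newDays = days + gap
      // only counts still inside the daysToCheck window contribute
      val newSum = if (newDays < daysToCheck) sum + count else sum
      (date, newSum, newDays)
  }
  countsToCheck <= total
}

// id 1's rows sorted descending, with daysToCheck = 5 and countsToCheck = 2:
println(foldBuffer(Seq(("2017-06-24", 1), ("2017-06-23", 0), ("2017-06-22", 1)), 5, 2))  // true
```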
You can invoke the defined UDAF as follows:
val manosAgg = new ManosAggregateFunction(5,2)
df.orderBy($"timestamp".desc).groupBy("id").agg(manosAgg(col("timestamp"), col("count")).as("code")).show
Final output:
+---+-----+
| id| code|
+---+-----+
| 1| true|
| 2|false|
+---+-----+
Given the input:
val df = Seq(
(1, "2017-06-22", 1),
(1, "2017-06-23", 0),
(1, "2017-06-24", 1),
(2, "2017-06-28", 0),
(2, "2017-06-29", 1)
).toDF("id","timestamp","count")
+---+----------+-----+
|id |timestamp |count|
+---+----------+-----+
|1 |2017-06-22|1 |
|1 |2017-06-23|0 |
|1 |2017-06-24|1 |
|2 |2017-06-28|0 |
|2 |2017-06-29|1 |
+---+----------+-----+
I hope this solves your problem. :)