I have the following example DataFrame:
Provider Patient Date
Smith John 2016-01-23
Smith John 2016-02-20
Smith John 2016-03-21
Smith John 2016-06-25
Smith Jill 2016-02-01
Smith Jill 2016-03-10
James Jill 2017-04-10
James Jill 2017-05-11
I would like to programmatically add a column indicating the number of consecutive months a patient has been seen. The new DataFrame would look like this:
Provider Patient Date consecutive_id
Smith John 2016-01-23 3
Smith John 2016-02-20 3
Smith John 2016-03-21 3
Smith John 2016-06-25 1
Smith Jill 2016-02-01 2
Smith Jill 2016-03-10 2
James Jill 2017-04-10 2
James Jill 2017-05-11 2
I assume there is a way to accomplish this with a Window function, but I haven't been able to figure it out, and I would appreciate any insight the community can offer. Thanks.
Answer 0 (score: 1)
There are at least 3 ways to get the result; see Introducing Window Functions in Spark SQL for background.
For all of the solutions, you can call .toDebugString to see the operations under the hood.
The SQL solution is below:
val my_df = List(
("Smith", "John", "2016-01-23"),
("Smith", "John", "2016-02-20"),
("Smith", "John", "2016-03-21"),
("Smith", "John", "2016-06-25"),
("Smith", "Jill", "2016-02-01"),
("Smith", "Jill", "2016-03-10"),
("James", "Jill", "2017-04-10"),
("James", "Jill", "2017-05-11")
).toDF(Seq("Provider", "Patient", "Date"): _*)
my_df.createOrReplaceTempView("tbl")
val q = """
select t2.*,
       -- consecutive_id: number of visits in each run of consecutive months
       count(*) over (partition by provider, patient, grp) consecutive_id
from (select t1.*,
             -- grp: running sum of x, giving each run its own id
             sum(x) over (partition by provider, patient order by yyyymm) grp
      from (select t0.*,
                   -- x: 1 at the start of a new run, 0 when the month follows the previous one
                   case
                     when cast(yyyymm as int) -
                          cast(lag(yyyymm) over (partition by provider, patient order by yyyymm) as int) = 1
                     then 0
                     else 1
                   end x
            from (select tbl.*, substr(translate(date, '-', ''), 1, 6) yyyymm from tbl) t0) t1) t2
"""
sql(q).show
sql(q).rdd.toDebugString
Output:
scala> sql(q).show
+--------+-------+----------+------+---+---+--------------+
|Provider|Patient| Date|yyyymm| x|grp|consecutive_id|
+--------+-------+----------+------+---+---+--------------+
| Smith| Jill|2016-02-01|201602| 1| 1| 2|
| Smith| Jill|2016-03-10|201603| 0| 1| 2|
| James| Jill|2017-04-10|201704| 1| 1| 2|
| James| Jill|2017-05-11|201705| 0| 1| 2|
| Smith| John|2016-01-23|201601| 1| 1| 3|
| Smith| John|2016-02-20|201602| 0| 1| 3|
| Smith| John|2016-03-21|201603| 0| 1| 3|
| Smith| John|2016-06-25|201606| 1| 2| 1|
+--------+-------+----------+------+---+---+--------------+
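Note that comparing raw yyyymm values only detects consecutive months within the same calendar year: 201612 followed by 201701 differs by 89, not 1, so a run that crosses New Year is split. Below is a sketch of a year-safe variant of the same query, using a month index (year * 12 + month) instead; the column name mix is my own, not part of the original answer:

val q2 = """
select t2.*, count(*) over (partition by provider, patient, grp) consecutive_id
from (select t1.*, sum(x) over (partition by provider, patient order by mix) grp
      from (select t0.*,
                   -- mix increases by exactly 1 from one month to the next, across years too
                   case when mix - lag(mix) over (partition by provider, patient order by mix) = 1
                        then 0 else 1
                   end x
            from (select tbl.*, year(to_date(date)) * 12 + month(to_date(date)) mix from tbl) t0) t1) t2
"""
sql(q2).show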
Update
A hybrid of .mapPartitions + .over(windowSpec):
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StringType, IntegerType, StructField, StructType}
val schema = new StructType().add(
StructField("provider", StringType, true)).add(
StructField("patient", StringType, true)).add(
StructField("date", StringType, true)).add(
StructField("x", IntegerType, true)).add(
StructField("grp", IntegerType, true))
def f(iter: Iterator[Row]) : Iterator[Row] = {
  // scanLeft threads (x, grp) through the partition; the seed row is dropped at the end
  iter.scanLeft(Row("_", "_", "000000", 0, 0)) {
    case (x1, x2) =>
      // x = 0 when this row's yyyymm is exactly one month after the previous row's
      val x =
        if (x2.getString(2).replaceAll("-", "").substring(0, 6).toInt ==
            x1.getString(2).replaceAll("-", "").substring(0, 6).toInt + 1)
          0 else 1
      // grp is a running sum of x, so each run of consecutive months gets its own id
      val grp = x1.getInt(4) + x
      Row(x2.getString(0), x2.getString(1), x2.getString(2), x, grp)
  }.drop(1)
}
// repartition by (provider, patient), then sort each partition so that every
// patient's visits are adjacent and in date order before f scans them
val df_mod = spark.createDataFrame(
  my_df.repartition($"provider", $"patient")
       .sortWithinPartitions($"provider", $"patient", $"date")
       .rdd.mapPartitions(f, true), schema)
import org.apache.spark.sql.expressions.Window
val windowSpec = Window.partitionBy($"provider", $"patient", $"grp")
df_mod.withColumn("consecutive_id", count(lit("1")).over(windowSpec)
).orderBy($"provider", $"patient", $"date").show
Output:
scala> df_mod.withColumn("consecutive_id", count(lit("1")).over(windowSpec)
| ).orderBy($"provider", $"patient", $"date").show
+--------+-------+----------+---+---+--------------+
|provider|patient| date| x|grp|consecutive_id|
+--------+-------+----------+---+---+--------------+
| James| Jill|2017-04-10| 1| 1| 2|
| James| Jill|2017-05-11| 0| 1| 2|
| Smith| Jill|2016-02-01| 1| 1| 2|
| Smith| Jill|2016-03-10| 0| 1| 2|
| Smith| John|2016-01-23| 1| 1| 3|
| Smith| John|2016-02-20| 0| 1| 3|
| Smith| John|2016-03-21| 0| 1| 3|
| Smith| John|2016-06-25| 1| 2| 1|
+--------+-------+----------+---+---+--------------+
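For completeness, the same logic can also be written with the DataFrame API alone, with no SQL string and no mapPartitions. This is a minimal sketch of my own, equivalent to the SQL solution above (and sharing its original year-boundary caveat):

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val byPatient = Window.partitionBy($"Provider", $"Patient").orderBy($"yyyymm")
val df2 = my_df
  .withColumn("yyyymm", substring(translate($"Date", "-", ""), 1, 6).cast("int"))
  // x = 1 at the start of each run of consecutive months, 0 inside a run
  .withColumn("x", when($"yyyymm" - lag($"yyyymm", 1).over(byPatient) === 1, 0).otherwise(1))
  // grp = running sum of x, giving each run its own id
  .withColumn("grp", sum($"x").over(byPatient))
  .withColumn("consecutive_id",
    count(lit(1)).over(Window.partitionBy($"Provider", $"Patient", $"grp")))
df2.show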
Answer 1 (score: 0)
You can:
1. Create a column that maps each date to an incrementing month index (2016-01 = 1, 2016-02 = 2, 2017-01 = 13, and so on).
2. Gather all of a patient's dates into an array using a window and collect_list:
val winSpec = Window.partitionBy("Provider","Patient").orderBy("Date")
df.withColumn("Dates", collect_list("Date").over(winSpec))
3. Pass the array into a modified version of @marios solution, registered as a UDF with spark.udf.register, to get the maximum number of consecutive months; a sketch of steps 1 and 3 follows.
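Here is a minimal sketch of steps 1 and 3, assuming df is the question's DataFrame. The UDF name max_consecutive, the column names month_idx and months, and the run-counting logic are illustrative assumptions of mine, not the original @marios code:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// step 1 (assumed implementation): month index that stays consecutive across year boundaries
val withIdx = df.withColumn("month_idx", year(to_date($"Date")) * 12 + month(to_date($"Date")))

// step 2: without orderBy, the window frame defaults to the whole partition,
// so every row sees the complete list of the patient's months
val winAll = Window.partitionBy("Provider", "Patient")

// step 3 (assumed implementation): longest run of consecutive indices in the array
spark.udf.register("max_consecutive", (months: Seq[Int]) => {
  val sorted = months.distinct.sorted
  sorted.foldLeft((0, 0, Int.MinValue)) { case ((best, run, prev), m) =>
    val r = if (m == prev + 1) run + 1 else 1
    (math.max(best, r), r, m)
  }._1
})

withIdx
  .withColumn("months", collect_list("month_idx").over(winAll))
  .withColumn("consecutive_id", expr("max_consecutive(months)"))
  .show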