Spark Scala: counting consecutive months

Date: 2017-10-17 22:44:26

Tags: scala apache-spark spark-dataframe

I have the following example DataFrame:

Provider  Patient  Date
Smith     John     2016-01-23
Smith     John     2016-02-20
Smith     John     2016-03-21
Smith     John     2016-06-25
Smith     Jill     2016-02-01
Smith     Jill     2016-03-10
James     Jill     2017-04-10
James     Jill     2017-05-11

I would like to programmatically add a column indicating the number of consecutive months a patient has been seen. The new DataFrame would look like this:

Provider  Patient  Date         consecutive_id
Smith     John     2016-01-23   3
Smith     John     2016-02-20   3
Smith     John     2016-03-21   3
Smith     John     2016-06-25   1
Smith     Jill     2016-02-01   2
Smith     Jill     2016-03-10   2
James     Jill     2017-04-10   2
James     Jill     2017-05-11   2

I assume there is a way to achieve this with Window functions, but I haven't been able to figure it out yet, and I'm looking forward to any insight the community can offer. Thanks.

2 answers:

Answer 0 (score: 1)

There are at least 3 ways to get the result:

  1. Implement the logic in SQL
  2. Use the Spark API for window functions - .over(windowSpec) (a DataFrame-only sketch of this approach appears after the SQL output below)
  3. Use .rdd.mapPartitions directly

  See also: Introducing Window Functions in Spark SQL

For all of these solutions you can call .toDebugString to see what happens under the hood.

The SQL solution is below:

    val my_df = List(
      ("Smith", "John", "2016-01-23"),
      ("Smith", "John", "2016-02-20"),
      ("Smith", "John", "2016-03-21"),
      ("Smith", "John", "2016-06-25"),
      ("Smith", "Jill", "2016-02-01"),
      ("Smith", "Jill", "2016-03-10"),
      ("James", "Jill", "2017-04-10"),
      ("James", "Jill", "2017-05-11")
      ).toDF(Seq("Provider", "Patient", "Date"): _*)
    
    my_df.createOrReplaceTempView("tbl")
    
    // Gaps-and-islands in SQL: derive a yyyymm month key, flag rows that start a
    // new run of consecutive months (x), turn the flags into a run id with a
    // running sum (grp), then count the rows of each run.
    // Note: the "yyyymm difference = 1" test treats a December-to-January step as
    // a break, since e.g. 201701 - 201612 is not 1.
    val q = """
    select t2.*, count(*) over (partition by provider, patient, grp) consecutive_id
      from (select t1.*, sum(x) over (partition by provider, patient order by yyyymm) grp
              from (select t0.*,
                           case
                              when cast(yyyymm as int) - 
                                   cast(lag(yyyymm) over (partition by provider, patient order by yyyymm) as int) = 1
                              then 0
                              else 1
                           end x
                      from (select tbl.*, substr(translate(date, '-', ''), 1, 6) yyyymm from tbl) t0) t1) t2
    """
    
    sql(q).show
    sql(q).rdd.toDebugString
    

Output:

    scala> sql(q).show
    +--------+-------+----------+------+---+---+--------------+
    |Provider|Patient|      Date|yyyymm|  x|grp|consecutive_id|
    +--------+-------+----------+------+---+---+--------------+
    |   Smith|   Jill|2016-02-01|201602|  1|  1|             2|
    |   Smith|   Jill|2016-03-10|201603|  0|  1|             2|
    |   James|   Jill|2017-04-10|201704|  1|  1|             2|
    |   James|   Jill|2017-05-11|201705|  0|  1|             2|
    |   Smith|   John|2016-01-23|201601|  1|  1|             3|
    |   Smith|   John|2016-02-20|201602|  0|  1|             3|
    |   Smith|   John|2016-03-21|201603|  0|  1|             3|
    |   Smith|   John|2016-06-25|201606|  1|  2|             1|
    +--------+-------+----------+------+---+---+--------------+
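
Approach 2 can also be written entirely with the DataFrame window-function API. Below is a minimal sketch (not the answer's own code), assuming a spark-shell session with the my_df defined above; the column names month_idx, x and grp are my own. Unlike the yyyymm subtraction in the SQL, it uses a year * 12 + month index, so a December-to-January step also counts as consecutive:

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions._

    // Month index = year * 12 + month, so any pair of adjacent calendar months
    // (including December -> January) differs by exactly 1.
    val withIdx = my_df.withColumn("month_idx",
      year($"Date".cast("date")) * 12 + month($"Date".cast("date")))

    val byPatient = Window.partitionBy($"Provider", $"Patient").orderBy($"month_idx")

    // Flag the first row of every run of consecutive months (x), then turn the
    // flags into a run id with a running sum (grp) -- the same gaps-and-islands
    // idea as the SQL above.
    val grouped = withIdx
      .withColumn("x",
        when($"month_idx" - lag($"month_idx", 1).over(byPatient) === 1, 0).otherwise(1))
      .withColumn("grp", sum($"x").over(byPatient))

    // The size of each run is the requested consecutive_id.
    grouped
      .withColumn("consecutive_id",
        count(lit(1)).over(Window.partitionBy($"Provider", $"Patient", $"grp")))
      .drop("month_idx", "x", "grp")
      .orderBy($"Provider", $"Patient", $"Date")
      .show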
    

Update:

A hybrid of .mapPartitions + .over(windowSpec):

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types.{StringType, IntegerType, StructField, StructType}
    
    val schema = new StructType().add(
                 StructField("provider", StringType, true)).add(
                 StructField("patient", StringType, true)).add(
                 StructField("date", StringType, true)).add(
                 StructField("x", IntegerType, true)).add(
                 StructField("grp", IntegerType, true))
    
    // For each partition, scanLeft walks the rows in order, comparing each row's
    // yyyymm to the previous row's and accumulating a run id (grp); the seed row
    // is dropped afterwards. As in the SQL version, a December-to-January step is
    // not treated as consecutive.
    def f(iter: Iterator[Row]) : Iterator[Row] = {
      iter.scanLeft(Row("_", "_", "000000", 0, 0))
      {
        case (x1, x2) =>
    
        val x = 
        if (x2.getString(2).replaceAll("-", "").substring(0, 6).toInt ==
            x1.getString(2).replaceAll("-", "").substring(0, 6).toInt + 1) 
        (0) else (1);
    
        val grp = x1.getInt(4) + x;
    
        Row(x2.getString(0), x2.getString(1), x2.getString(2), x, grp);
      }.drop(1)
    }
    
    // Repartition by the grouping keys and sort by them (plus date) so that each
    // (provider, patient) group is contiguous and date-ordered within its partition.
    val df_mod = spark.createDataFrame(my_df.repartition($"provider", $"patient")
                                            .sortWithinPartitions($"provider", $"patient", $"date")
                                            .rdd.mapPartitions(f, true), schema)
    
    import org.apache.spark.sql.expressions.Window
    val windowSpec = Window.partitionBy($"provider", $"patient", $"grp")
    df_mod.withColumn("consecutive_id", count(lit("1")).over(windowSpec)
         ).orderBy($"provider", $"patient", $"date").show
    

Output:

    scala> df_mod.withColumn("consecutive_id", count(lit("1")).over(windowSpec)
         |      ).orderBy($"provider", $"patient", $"date").show
    +--------+-------+----------+---+---+--------------+
    |provider|patient|      date|  x|grp|consecutive_id|
    +--------+-------+----------+---+---+--------------+
    |   James|   Jill|2017-04-10|  1|  1|             2|
    |   James|   Jill|2017-05-11|  0|  1|             2|
    |   Smith|   Jill|2016-02-01|  1|  1|             2|
    |   Smith|   Jill|2016-03-10|  0|  1|             2|
    |   Smith|   John|2016-01-23|  1|  1|             3|
    |   Smith|   John|2016-02-20|  0|  1|             3|
    |   Smith|   John|2016-03-21|  0|  1|             3|
    |   Smith|   John|2016-06-25|  1|  2|             1|
    +--------+-------+----------+---+---+--------------+
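
If only the columns requested in the question are wanted, the helper columns can be dropped afterwards; a small variation on the call above:

    df_mod.withColumn("consecutive_id", count(lit("1")).over(windowSpec))
          .drop("x", "grp")
          .orderBy($"provider", $"patient", $"date")
          .show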
    

Answer 1 (score: 0)

You could:

  1. Reformat the dates as integers (2016-01 = 1, 2016-02 = 2, 2017-01 = 13, and so on)
  2. Collect all of a patient's dates into one array with a window and collect_list:

    val winSpec = Window.partitionBy("Provider", "Patient").orderBy("Date")
    df.withColumn("Dates", collect_list("Date").over(winSpec))

  3. Pass the array into a UDF registered with spark.udf.register, a modified version of @marios solution, to get the maximum number of consecutive months (see the sketch below)
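
A minimal sketch of these three steps is below. The @marios solution referred to above is not reproduced here, so the run-counting UDF is a stand-in of my own, and the names month_int, months and maxConsecutive are hypothetical; the question's DataFrame is assumed to be named df, as in the step-2 snippet. The UDF is built with functions.udf here; spark.udf.register would be the route if you want to call it from SQL. The window deliberately has no orderBy so that every row sees the group's complete list (with the orderBy from step 2 the frame would be a running prefix instead).

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions._

    // Step 1: month index with 2016-01 = 1, 2016-02 = 2, 2017-01 = 13, ...
    val withIdx = df.withColumn("month_int",
      (year($"Date".cast("date")) - 2016) * 12 + month($"Date".cast("date")))

    // Step 2: collect every month index of the (Provider, Patient) group into an array
    val winSpec = Window.partitionBy("Provider", "Patient")
    val withMonths = withIdx.withColumn("months", collect_list("month_int").over(winSpec))

    // Step 3: a UDF returning the length of the longest run of consecutive month indexes
    val maxConsecutive = udf { (months: Seq[Int]) =>
      months.distinct.sorted
        .foldLeft((0, 0, Int.MinValue)) { case ((best, run, prev), m) =>
          val r = if (m == prev + 1) run + 1 else 1
          (math.max(best, r), r, m)
        }._1
    }

    withMonths.withColumn("consecutive_id", maxConsecutive($"months"))
              .drop("month_int", "months")
              .show

Note that this computes the longest run of consecutive months per (Provider, Patient) group, which is what this answer describes, rather than the per-row run length shown in the question's expected output.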