Splitting a row into two rows and populating certain columns

Time: 2018-03-01 10:14:50

Tags: scala apache-spark apache-spark-sql scala-collections

I need to split a row and create a new row by changing the date columns and setting the Amt column to zero, as in the example below:

Input:  
+---+-----------------------+-----------------------+-----+
|KEY|START_DATE             |END_DATE               |Amt  |
+---+-----------------------+-----------------------+-----+
|0  |2016-12-14T23:59:59.000|2017-10-29T23:59:58.000|100.0|
|0  |2016-12-14T23:59:59.000|2017-10-29T23:59:58.000|200.0|
|0  |2017-10-30T00:00:00.000|2017-11-02T23:59:59.000|67.5 |->split this row: "2017-10-31T23:59:59" falls between START_DATE and END_DATE
|0  |2016-12-14T23:59:59.000|2017-10-29T23:59:58.000|55.3 |
|1  |2016-12-14T23:59:59.000|2017-10-29T23:59:58.000|22.2 |
|1  |2017-10-30T00:00:00.000|2017-11-01T23:59:59.000|11.0 |->split this row: "2017-10-31T23:59:59" falls between START_DATE and END_DATE
|1  |2016-12-14T23:59:59.000|2017-10-29T23:59:58.000|87.33|
+---+-----------------------+-----------------------+-----+

If "2017-10-31T23:59:59" falls between a row's START_DATE and END_DATE, split that row into two by changing the END_DATE of one row and the START_DATE of the other, and set the Amt of the new row to zero, as shown below:

Desired output:

+---+-----------------------+-----------------------+-----+---+
|KEY|START_DATE             |END_DATE               |Amt  |Ind|
+---+-----------------------+-----------------------+-----+---+
|0  |2016-12-14T23:59:59.000|2017-10-29T23:59:58.000|100.0|N  |
|0  |2016-12-14T23:59:59.000|2017-10-29T23:59:58.000|200.0|N  |

|0  |2017-10-30T00:00:00.000|2017-10-30T23:59:59.998|67.5 |N  |->parent row (END_DATE changed)
|0  |2017-10-30T23:59:59.999|2017-11-02T23:59:59.000|0.0  |Y  |->split new row (START_DATE changed and Amt=0.0)

|0  |2016-12-14T23:59:59.000|2017-10-29T23:59:58.000|55.3 |N  |     
|1  |2016-12-14T23:59:59.000|2017-10-29T23:59:58.000|22.2 |N  |

|1  |2017-10-30T00:00:00.000|2017-10-30T23:59:59.998|11.0 |N  |->parent row (END_DATE changed)
|1  |2017-10-30T23:59:59.999|2017-11-01T23:59:59.000|0.0  |Y  |->split new row (START_DATE changed and Amt=0.0)

|1  |2016-12-14T23:59:59.000|2017-10-29T23:59:58.000|87.33|N  |     
+---+-----------------------+-----------------------+-----+---+

I have tried the code below; I am able to duplicate the row, but I am not able to update the rows dynamically.

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Encoder, Row}

// assumes an existing SparkSession named `spark`
import spark.implicits._ // supplies the String encoder needed by groupByKey below

val df1Columns = Seq("KEY", "START_DATE", "END_DATE", "Amt")
val df1Schema = StructType(df1Columns.map(c => StructField(c, StringType, nullable = false)))

// sample input rows (all columns are kept as strings)
val rows = Seq(
  Row("0", "2016-12-14T23:59:59.000", "2017-10-29T23:59:58.000", "100.0"),
  Row("0", "2016-12-14T23:59:59.000", "2017-10-29T23:59:58.000", "200.0"),
  Row("0", "2017-10-30T00:00:00.000", "2017-11-01T23:59:59.000", "67.5"),
  Row("0", "2016-12-14T23:59:59.000", "2017-10-29T23:59:58.000", "55.3"),
  Row("1", "2016-12-14T23:59:59.000", "2017-10-29T23:59:58.000", "22.2"),
  Row("1", "2017-10-30T00:00:00.000", "2017-11-01T23:59:59.000", "11.0"),
  Row("1", "2016-12-14T23:59:59.000", "2017-10-29T23:59:58.000", "87.33"))

val rdd: RDD[Row] = spark.sparkContext.parallelize(rows)
val df: DataFrame = spark.createDataFrame(rdd, df1Schema)

  //----------------------------------------------------------------

// Row encoder for the given column names (all columns are strings)
def encoder(columns: Seq[String]): Encoder[Row] =
  RowEncoder(StructType(columns.map(StructField(_, StringType, nullable = true))))

val outputColumns = Seq("KEY", "START_DATE", "END_DATE", "Amt", "Ind") // intended output columns (not used yet)

val result = df.groupByKey(r => r.getAs[String]("KEY"))
  .flatMapGroups((_, rowsForAkey) => {
    var result: List[Row] = List()
    for (row <- rowsForAkey) {
      val qrDate = "2017-10-31T23:59:59"
      val currRowStartDate = row.getAs[String]("START_DATE")
      val rowEndDate = row.getAs[String]("END_DATE")
      if (currRowStartDate <= qrDate && qrDate <= rowEndDate) {
        // the boundary date falls inside this row: add a copy (not yet modified)
        val rLayer = row
        result = result :+ rLayer
      }
      val originalRow = row
      result = result :+ originalRow
    }
    result
  })(encoder(df1Columns)).toDF

df.show(false)
result.show(false)

  df.show(false)
  result.show(false)

Here is the output of my code:

+---+-----------------------+-----------------------+-----+
|KEY|START_DATE             |END_DATE               |Amt  |
+---+-----------------------+-----------------------+-----+
|0  |2016-12-14T23:59:59.000|2017-10-29T23:59:58.000|100.0|     
|0  |2016-12-14T23:59:59.000|2017-10-29T23:59:58.000|200.0|     
|0  |2017-10-30T00:00:00.000|2017-11-01T23:59:59.000|67.5 |     
|0  |2017-10-30T00:00:00.000|2017-11-01T23:59:59.000|67.5 |     
|0  |2016-12-14T23:59:59.000|2017-10-29T23:59:58.000|55.3 |     
|1  |2016-12-14T23:59:59.000|2017-10-29T23:59:58.000|22.2 |     
|1  |2017-10-30T00:00:00.000|2017-11-01T23:59:59.000|11.0 |     
|1  |2017-10-30T00:00:00.000|2017-11-01T23:59:59.000|11.0 |     
|1  |2016-12-14T23:59:59.000|2017-10-29T23:59:58.000|87.33|     
+---+-----------------------+-----------------------+-----+

2 Answers:

Answer 0 (score: 2)

It looks like you're duplicating the rows rather than changing them.

You can replace the inside of your flatMapGroups function with the following:

rowsForAkey.flatMap { row =>
  val qrDate = "2017-10-31T23:59:59"
  val currRowStartDate = row.getAs[String]("START_DATE")
  val rowEndDate = row.getAs[String]("END_DATE")
  if (currRowStartDate <= qrDate && qrDate <= rowEndDate) // the split condition
  {
    // need to build two rows: the parent keeps its Amt, the new row gets Amt = "0.0"
    val splitDate = endOfDay(currRowStartDate)
    val parentRow = Row(row(0), row(1), splitDate, row(3), "N")
    val splitRow = Row(row(0), splitDate, row(2), "0.0", "Y")
    List(parentRow, splitRow)
  }
  else {
    // non-matching rows also get the Ind column so that every output row has five fields
    List(Row(row(0), row(1), row(2), row(3), "N"))
  }
}
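
One more adjustment: the rows built above carry five fields, so the encoder passed to flatMapGroups has to be built from outputColumns (which includes Ind) rather than df1Columns, i.e. encoder(outputColumns).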

Basically, any time you have a for loop building up a list like this in Scala, what you really want is a map or flatMap. Here it's flatMap, because each row yields one or two elements in the result. I'm assuming you introduce a function endOfDay to produce the correct timestamp.
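
Since endOfDay is not defined anywhere in the question, here is a minimal sketch, assuming the timestamps stay plain strings. The exact millisecond value is an assumption (your desired output uses .998 for the parent row and .999 for the new row, while this helper produces a single split value):

// Minimal sketch of the assumed endOfDay helper: keep the date part of the
// timestamp string and replace the time with an end-of-day value (adjust the
// milliseconds to whatever boundary convention you want).
def endOfDay(ts: String): String = ts.split("T")(0) + "T23:59:59.999"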

I realize you're probably reading the data in a way that gives you a DataFrame, but I did want to offer the idea of using a Dataset[SomeCaseClass]: it's essentially a drop-in replacement (you're basically treating your DataFrame as a Dataset[Row], which is what it is, after all), I think it makes things easier to read, and you get type checking as well.

Also, if you import spark.implicits._, you shouldn't need the encoder: everything looks to be a String or a floating-point number, and encoders for those are available.
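
A rough sketch of that idea (the case class names are illustrative, and it reuses the endOfDay sketch above):

import spark.implicits._ // provides encoders for case classes, Strings, etc.

// Illustrative case classes mirroring the input and output schemas; Amt stays a
// String to match the all-string schema in the question. In a real application,
// define these at top level rather than inside a method.
case class Interval(KEY: String, START_DATE: String, END_DATE: String, Amt: String)
case class SplitInterval(KEY: String, START_DATE: String, END_DATE: String, Amt: String, Ind: String)

val qrDate = "2017-10-31T23:59:59"

val result = df.as[Interval]
  .groupByKey(_.KEY)
  .flatMapGroups { (_, rows) =>
    rows.flatMap { r =>
      if (r.START_DATE <= qrDate && qrDate <= r.END_DATE) {
        val splitDate = endOfDay(r.START_DATE)
        Seq(
          SplitInterval(r.KEY, r.START_DATE, splitDate, r.Amt, "N"), // parent row
          SplitInterval(r.KEY, splitDate, r.END_DATE, "0.0", "Y"))   // new zero-amount row
      } else {
        Seq(SplitInterval(r.KEY, r.START_DATE, r.END_DATE, r.Amt, "N"))
      }
    }
  }

result.show(false)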

Answer 1 (score: 2)

Rather than going the complicated RDD way, I would suggest you use the built-in functions.

I used built-in functions such as lit to populate constants and a udf function to change the time portion of the date columns.

The idea is to split the dataframe into two and finally union them (I have commented the code for clarity).

import org.apache.spark.sql.functions._
//udf that copies the time portion of the first column onto the date portion of the second column
def changeTimeInDate = udf((toCopy: String, withCopied: String) => withCopied.split("T")(0) + "T" + toCopy.split("T")(1))

//create the Ind column populated with "N" and save it in a temporary dataframe
val indDF = df.withColumn("Ind", lit("N"))

//keep only the rows that match the condition mentioned in the question, then change the Amt, Ind and START_DATE columns
val duplicatedDF = indDF.filter($"START_DATE" <= "2017-10-31T23:59:59" && $"END_DATE" >= "2017-10-31T23:59:59")
  .withColumn("Amt", lit("0.0"))
  .withColumn("Ind", lit("Y"))
  .withColumn("START_DATE", changeTimeInDate($"END_DATE", $"START_DATE"))

//change the END_DATE and finally union both dataframes
val result = indDF.withColumn("END_DATE", changeTimeInDate($"START_DATE", $"END_DATE"))
  .union(duplicatedDF)

You should have the desired output:

+---+-----------------------+-----------------------+-----+---+
|KEY|START_DATE             |END_DATE               |Amt  |Ind|
+---+-----------------------+-----------------------+-----+---+
|0  |2016-12-14T23:59:59.000|2017-10-29T23:59:59.000|100.0|N  |
|0  |2016-12-14T23:59:59.000|2017-10-29T23:59:59.000|55.3 |N  |
|0  |2016-12-14T23:59:59.000|2017-10-29T23:59:59.000|200.0|N  |
|0  |2017-10-30T00:00:00.000|2017-11-01T00:00:00.000|67.5 |N  |
|0  |2017-10-30T23:59:59.000|2017-11-01T23:59:59.000|0.0  |Y  |
|1  |2016-12-14T23:59:59.000|2017-10-29T23:59:59.000|22.2 |N  |
|1  |2016-12-14T23:59:59.000|2017-10-29T23:59:59.000|87.33|N  |
|1  |2017-10-30T00:00:00.000|2017-11-01T00:00:00.000|11.0 |N  |
|1  |2017-10-30T23:59:59.000|2017-11-01T23:59:59.000|0.0  |Y  |
+---+-----------------------+-----------------------+-----+---+
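
One caveat: the snippet above applies changeTimeInDate to the END_DATE of every row, which is why the non-matching rows end at 23:59:59.000 in this output instead of the original 23:59:58.000. If those rows should stay exactly as they were, a guarded variant (a sketch reusing the same boundary literal and the indDF/duplicatedDF dataframes from above) could be:

// only rewrite END_DATE on the rows that actually contain the boundary date;
// assumes import spark.implicits._ for the $ syntax, as in the filter above
val boundary = "2017-10-31T23:59:59"
val resultWithOriginalEndDates = indDF
  .withColumn("END_DATE",
    when($"START_DATE" <= boundary && $"END_DATE" >= boundary,
      changeTimeInDate($"START_DATE", $"END_DATE"))
      .otherwise($"END_DATE"))
  .union(duplicatedDF)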