I have the following DataFrame and want to add a remarks column based on column values, using Scala

Asked: 2018-04-11 17:02:44

Tags: scala apache-spark

Below is my input:

id    val  visits  date
111   2        1   20160122
111   2        1   20170122
112   4        2   20160122
112   5        4   20150122
113   6        1   20100120
114   8        2   20150122
114   8        2   20150122

Expected output:

id    val  visits  date        remarks
111   2        1   20160122    oldDate
111   2        1   20170122    recentdate
112   4        2   20160122    less
112   5        4   20150122    more
113   6        1   20100120    one
114   8        2   20150122    Ramdom
114   8        2   20150122    Ramdom

The remarks should be assigned as follows:

- Random: the id has two records with identical val, visits, and date.
- one: the id has only one record, with any number of visits.
- less: the id has two records and this one has fewer visits than the other.
- more: the id has two records with different val and visits, and this one has more visits.
- recentdate: the id has two records with the same val and visits but different dates; this is the record with the latest date.
- oldDate: the id has two records with the same val and visits but different dates; this is the record with the earliest date.
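The rules above can be sketched in plain Scala, independent of Spark, to make the intended classification unambiguous. `Record` and `classify` are hypothetical names introduced here for illustration; they are not part of the question's code:

```scala
// A plain-Scala sketch of the remark rules (no Spark required).
case class Record(id: Int, value: Int, visits: Int, date: Int)

def classify(group: Seq[Record]): Seq[String] = group match {
  // a single record for the id
  case Seq(_) => Seq("one")
  // two fully identical records
  case Seq(a, b) if a == b => Seq("Random", "Random")
  // same value and visits, different dates: earlier is oldDate, later is recentdate
  case Seq(a, b) if a.value == b.value && a.visits == b.visits =>
    if (a.date < b.date) Seq("oldDate", "recentdate") else Seq("recentdate", "oldDate")
  // otherwise: fewer visits is "less", the other "more"
  case Seq(a, b) =>
    if (a.visits < b.visits) Seq("less", "more") else Seq("more", "less")
  case _ => group.map(_ => "not defined")
}
```

Running `classify` on each id group of the sample input reproduces the expected remarks column.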

My code:

// assumes `import org.apache.spark.sql.functions._` and `import spark.implicits._` are in scope
val grouped = df.groupBy("id").agg(
  max($"val").as("maxVal"), max($"visits").as("maxVisits"),
  min($"val").as("minVal"), min($"visits").as("minVisits"),
  count($"id").as("count"))

val remarks = functions.udf ((value: Int, visits: Int, maxValue: Int, maxVisits: Int, minValue: Int, minVisits: Int, count: Int) =>
   if (count == 1) {
     "One Visit"
   }else if (value == maxValue && value == minValue && visits == maxVisits && visits == minVisits) {
     "Random"
   }else {
     if (visits < maxVisits) {
       "Less Visits"
     }else {
       "More Visits"
     }
   }
 )



df.join(grouped, Seq("id"))
   .withColumn("remarks", remarks($"val", $"visits", $"maxVal", $"maxVisits", $"minVal", $"minVisits", $"count"))
   .drop("maxVal","maxVisits", "minVal", "minVisits", "count")
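The decision logic inside the udf above can be exercised as a plain Scala function before wiring it into Spark. `remarkFor` is a hypothetical extraction of the lambda, shown here only to make the branches testable; note it covers "One Visit", "Random", "Less Visits", and "More Visits", but not the oldDate/recentdate cases:

```scala
// Hypothetical extraction of the udf body so it can be run without a SparkSession.
def remarkFor(value: Int, visits: Int, maxValue: Int, maxVisits: Int,
              minValue: Int, minVisits: Int, count: Int): String =
  if (count == 1) "One Visit"                     // only one record for the id
  else if (value == maxValue && value == minValue &&
           visits == maxVisits && visits == minVisits) "Random"  // all records identical
  else if (visits < maxVisits) "Less Visits"
  else "More Visits"
```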

1 Answer:

Answer 0 (score: 0)

The code below should work for you (though it is not efficient, since it has many if/else branches):

import org.apache.spark.sql.Row
import org.apache.spark.sql.functions._

// Each group's rows arrive as a Seq[Row] from collect_list(struct("val", "visits", "date"))
def remarkUdf = udf((column: Seq[Row]) => {
  if (column.size == 1) {
    // a single record for the id
    Seq(remarks(column(0).getAs(0), column(0).getAs(1), column(0).getAs(2), "one"))
  } else if (column.size == 2) {
    if (column(0) == column(1)) {
      // two fully identical records
      column.map(x => remarks(x.getAs(0), x.getAs(1), x.getAs(2), "Random"))
    } else if (column(0).getAs(0) == column(1).getAs(0) && column(0).getAs(1) == column(1).getAs(1)) {
      // same val and visits, different dates: earlier date is oldDate, later is recentdate
      if (column(0).getAs[Int](2) < column(1).getAs[Int](2))
        Seq(remarks(column(0).getAs(0), column(0).getAs(1), column(0).getAs(2), "oldDate"),
            remarks(column(1).getAs(0), column(1).getAs(1), column(1).getAs(2), "recentdate"))
      else
        Seq(remarks(column(0).getAs(0), column(0).getAs(1), column(0).getAs(2), "recentdate"),
            remarks(column(1).getAs(0), column(1).getAs(1), column(1).getAs(2), "oldDate"))
    } else {
      // different val and visits: smaller val and fewer visits is "less", the other "more"
      if (column(0).getAs[Int](0) < column(1).getAs[Int](0) && column(0).getAs[Int](1) < column(1).getAs[Int](1))
        Seq(remarks(column(0).getAs(0), column(0).getAs(1), column(0).getAs(2), "less"),
            remarks(column(1).getAs(0), column(1).getAs(1), column(1).getAs(2), "more"))
      else
        Seq(remarks(column(0).getAs(0), column(0).getAs(1), column(0).getAs(2), "more"),
            remarks(column(1).getAs(0), column(1).getAs(1), column(1).getAs(2), "less"))
    }
  } else {
    // more than two records per id is not covered by the rules
    column.map(x => remarks(x.getAs(0), x.getAs(1), x.getAs(2), "not defined"))
  }
})

df.groupBy("id").agg(collect_list(struct("val", "visits", "date")).as("value"))
  .withColumn("value", explode(remarkUdf(col("value"))))
  .select(col("id"), col("value.*"))
  .show(false)

It should give you:

+---+-----+------+--------+----------+
|id |value|Visits|date    |Remarks   |
+---+-----+------+--------+----------+
|111|2    |1     |20160122|oldDate   |
|111|2    |1     |20170122|recentdate|
|112|4    |2     |20160122|less      |
|112|5    |4     |20150122|more      |
|114|8    |2     |20150122|Random    |
|114|8    |2     |20150122|Random    |
|113|6    |1     |20100120|one       |
+---+-----+------+--------+----------+

You will also need the following case class:

case class remarks(value: Int, Visits: Int, date: Int, Remarks: String)