如何使用Scala / Spark查找每个ID的最大连续年份

时间:2016-04-21 18:56:58

标签: scala apache-spark

我对每一行都有一定的ID以及相应的操作年限:

示例:

ID   YEAR

A1   1999
A2   2000
A1   2000
B1   1998
A1   2002

现在,我需要确定每个ID的连续年数 结果,

A1 : 2  because[1999, 2000 ] 

等,

3 个答案:

答案 0 :(得分:2)

如果您不想使用Spark SQL(在我看来,这对任务来说太过分了),您可以简单地使用 groupByKey (可能的年数)每个id是合理的)

val rdd = sc.parallelize(Seq(
  ("A1", 1999),
  ("A2", 2000),
  ("A1", 2000),
  ("A1", 1998),
  ("A1", 2002),
  ("B1", 1998)
))

def findMaxRange(l: Iterable[Int]) = {
  val ranges = mutable.ArrayBuffer[Int](1)
  l.toSeq.sorted.distinct.sliding(2).foreach { case y1 :: tail =>
    if (tail.nonEmpty) {
      val y2 = tail.head
      if (y2 - y1 == 1)  ranges(ranges.size - 1) +=  1
      else ranges += 1
    }
  }
  ranges.max
}

rdd1.groupByKey.map(r => (r._1, findMaxRange(r._2))).collect()

res7: Array[(String, Int)] = Array((A1,3), (A2,1), (B1,1))

答案 1 :(得分:1)

如果您想要Spark解决方案,我会选择DataFrame。它变得混乱,但这是一个有趣的问题:

val testDf = Seq(
  ("A1", 1999),
  ("A2", 2000),
  ("A1", 2000),
  ("A1", 1998),
  ("A1", 2002),
  ("B1", 1998)
).toDF("ID", "YEAR")

然后我会进行自我加入(实际上是两个中的第一个):

val selfJoined = testDf.orderBy($"YEAR").join(
  testDf.orderBy($"YEAR").toDF("R_ID", "R_YEAR"),
  $"R_ID" === $"ID" && $"YEAR" === ($"R_YEAR" - 1),
  "full_outer"
).filter($"ID".isNull || $"R_ID".isNull)

selfJoined.show
+----+----+----+------+
|  ID|YEAR|R_ID|R_YEAR|
+----+----+----+------+
|null|null|  A2|  2000|
|  A2|2000|null|  null|
|null|null|  B1|  1998|
|  B1|1998|null|  null|
|null|null|  A1|  1998|
|  A1|2000|null|  null|
|null|null|  A1|  2002|
|  A1|2002|null|  null|
+----+----+----+------+

从上面可以看出,我们现在有连续几年的开始和结束日期。 R_YEAR,当不是null时,包含连续几年“运行”的开始。下一行,YEAR是这一年的结束。如果我更擅长Window功能,我可能会使用lag将记录拼接在一起,但我不是这样,我不会。我会做另一个自我加入,然后是groupBy,然后是select中的一些数学,然后是另一个groupBy

selfJoined.filter($"ID".isNull).as("a").join(
  selfJoined.filter($"R_ID".isNull).as("b"),
  $"a.R_ID" === $"b.ID" && $"a.R_YEAR" <= $"b.YEAR"
).groupBy($"a.R_ID", $"a.R_YEAR").agg(min($"b.YEAR") as "last_YEAR")
 .select($"R_ID" as "ID", $"last_YEAR" - $"R_YEAR" + 1 as "inarow")
 .groupBy($"ID").agg(max($"inarow") as "MAX").show
+---+---+
| ID|MAX|
+---+---+
| B1|  1|
| A1|  3|
| A2|  1|
+---+---+

Wheee!

答案 2 :(得分:0)

我会尝试这些方法:

scala> case class DataRow(id: String, year: Int)
defined class DataRow
scala> val data = Seq(
           DataRow("A1", 1999),
           DataRow("A2", 2000),
           DataRow("A1", 2000),
           DataRow("B1", 1998),
           DataRow("A1", 2002)
         )
data: Seq[DataRow] = List(DataRow("A1", 1999), DataRow("A2", 2000), DataRow("A1", 2000), DataRow("B1", 1998), DataRow("A1", 2002))
scala> data.groupBy(_.id).mapValues { rows =>
           val years = rows.map(_.year)
           val firstYear = years.head
           years.zipWithIndex.takeWhile { case (y, i) => y == firstYear + i }.size
         }
res1: Map[String, Int] = Map("B1" -> 1, "A2" -> 1, "A1" -> 2)

这计算每个ID的最大连续年数,假设它看到的第一年是最早的罢工日期。在.sorted行中插入val years,情况并非如此。