我想在一个非常大的数据集中找到连续的时间戳。这需要使用Java在Spark中完成(Scala中的代码示例也非常受欢迎)。
每行如下所示:
ID,开始时间,结束时间
例如数据集:
[[1, 10, 15],[1, 15, 20], [2, 10, 13], [1, 22, 33], [2, 13, 16]]
预期结果是具有相同id的所有连续时间范围,对于每个连续时间范围,只有开始时间和结束时间:
[[1, 10, 20],[1, 22, 33], [2, 10, 16]]
我已经尝试了以下方法,但是由于无法维持顺序,因此无法解决。因此,我希望有一种更有效的方法
textFile.mapToPair(x -> new Tuple2<>(x[0],new Tuple2<>(x[1], x[2])
.mapValues(x -> new LinkedList<>(Arrays.asList(x)))
.reduceByKey((x,y) -> {
Tuple2<Long, Long> v1 = x.getLast();
Tuple2<Long, Long> v2 = y.getFirst();
Tuple2<Long, Long> v3 = v2;
if(v2._1().equals(v1._2())) {
v3 = new Tuple2<>(v1._1(), v2._2());
x.removeLast();
}
x.addLast(v3);
return x;
})
.flatMapValues(x -> x);
答案 0 :(得分:1)
我认为这不是一个Spark问题,而是合乎逻辑的问题。 您应该考虑使用几个独立功能的选择:
bindEntries()
)insertEntry()
)建议,我们有模拟数据mockData
:
+---+-----+---+
| id|start|end|
+---+-----+---+
| 1| 22| 33|
| 1| 15| 20|
| 1| 10| 15|
| 2| 13| 16|
| 2| 10| 13|
+---+-----+---+
借助这些功能,我对您的问题的解决方案将是这样的:
val processed = mockData
.groupByKey(_.id)
.flatMapGroups { (id: Int, it: Iterator[Entry]) =>
processEntries(it)
}
processEntries()
的唯一目标是将每个id的所有条目折叠到非相交间隔的集合中。
这是签名:
def processEntries(it: Iterator[Entry]): List[Entry] =
it.foldLeft(Nil: List[Entry])(insertEntry)
此功能用于从分组条目中逐一获取元素,并将它们也逐一推入累加器。
函数insertEntry()
处理这种插入:
def insertEntry(acc: List[Entry], e: Entry): List[Entry] = acc match {
case Nil => e :: Nil
case a :: as =>
val combined = bindEntries(a, e)
combined match {
case x :: y :: Nil => x :: insertEntry(as, y)
case x :: Nil => insertEntry(as, x)
case _ => a :: as
}
}
bindEntries()
函数应为您处理条目的顺序:
def bindEntries(x: Entry, y: Entry): List[Entry] =
(x.start > y.end, x.end < y.start) match {
case (true, _) => y :: x :: Nil
case (_, true) => x :: y :: Nil
case _ => x.copy(start = x.start min y.start, end = x.end max y.end) :: Nil
}
bindEntries()
将返回正确排序的一两个条目的列表。
这是其背后的想法:
insertEntry()
将在插入过程中为您排序所有条目。
毕竟,结果数据集如下所示:
+---+-----+---+
| id|start|end|
+---+-----+---+
| 1| 10| 20|
| 1| 22| 33|
| 2| 10| 16|
+---+-----+---+
注意:函数insertEntry()
不是尾部递归的。
有一个进行进一步优化的良好起点。
有完整的解决方案:
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.{Dataset, Encoder, Encoders, SparkSession}
object AdHoc {
Logger.getLogger("org").setLevel(Level.OFF)
def main(args: Array[String]): Unit = {
import spark.implicits._
val processed = mockData
.groupByKey(_.id)
.flatMapGroups { (id, it) =>
processEntries(it)
}
mockData.show()
processed.show()
}
def processEntries(it: Iterator[Entry]): List[Entry] =
it.foldLeft(Nil: List[Entry])(insertEntry)
def insertEntry(acc: List[Entry], e: Entry): List[Entry] = acc match {
case Nil => e :: Nil
case a :: as =>
val combined = bindEntries(a, e)
combined match {
case x :: y :: Nil => x :: insertEntry(as, y)
case x :: Nil => insertEntry(as, x)
case _ => a :: as
}
}
def bindEntries(x: Entry, y: Entry): List[Entry] =
(x.start > y.end, x.end < y.start) match {
case (true, _) => y :: x :: Nil
case (_, true) => x :: y :: Nil
case _ => x.copy(start = x.start min y.start, end = x.end max y.end) :: Nil
}
lazy val mockData: Dataset[Entry] = spark.createDataset(Seq(
Entry(1, 22, 33),
Entry(1, 15, 20),
Entry(1, 10, 15),
Entry(2, 13, 16),
Entry(2, 10, 13)
))
case class Entry(id: Int, start: Int, end: Int)
implicit lazy val entryEncoder: Encoder[Entry] = Encoders.product[Entry]
lazy val spark: SparkSession = SparkSession.builder()
.master("local")
.getOrCreate()
}