I have been scratching my head trying to come up with a way to reduce a dataframe in Spark to a frame that records gaps in the dataframe, preferably without completely killing parallelism. Here is a much-simplified example (it's a bit lengthy because I wanted it to be able to run):
import org.apache.spark.sql.SparkSession
case class Record(typ: String, start: Int, end: Int);
object Sample {
def main(argv: Array[String]): Unit = {
val sparkSession = SparkSession.builder()
.master("local")
.getOrCreate();
val df = sparkSession.createDataFrame(
Seq(
Record("One", 0, 5),
Record("One", 10, 15),
Record("One", 5, 8),
Record("Two", 10, 25),
Record("Two", 40, 45),
Record("Three", 30, 35)
)
);
df.repartition(df("typ")).sortWithinPartitions(df("start")).show();
}
}
When I'm done, I would like to be able to output a dataframe like this:
typ start end
--- ----- ---
One 0 8
One 10 15
Two 10 25
Two 40 45
Three 30 35
I guessed that partitioning by the 'typ' value would give me partitions with one distinct data value each, 1-1, e.g. in the sample I would end up with three partitions, one each for 'One', 'Two' and 'Three'. Furthermore, the sortWithinPartitions call is intended to give me each partition in sorted order on 'start' so that I can iterate from the beginning to the end and record gaps. That last part is where I am stuck. Is this possible? If not, is there another approach that is?
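To make that last step concrete, here is a rough sketch of the kind of per-partition scan being described, assuming the Dataset API and import sparkSession.implicits._ (note that hash partitioning can put several 'typ' values into the same partition, so the scan has to reset whenever 'typ' changes):

import sparkSession.implicits._

val merged = df.as[Record]
  .repartition($"typ")
  .sortWithinPartitions($"typ", $"start")
  .mapPartitions { records =>
    // Fold over the sorted records, extending the current interval while
    // the next record overlaps it, otherwise emitting the finished interval.
    val out = scala.collection.mutable.ListBuffer.empty[Record]
    var current: Option[Record] = None
    for (r <- records) {
      current = current match {
        case Some(c) if c.typ == r.typ && r.start <= c.end =>
          Some(c.copy(end = c.end max r.end))   // overlap: extend the current interval
        case Some(c) =>
          out += c; Some(r)                     // gap or new typ: emit and start a new interval
        case None =>
          Some(r)
      }
    }
    current.foreach(out += _)
    out.iterator
  }
merged.show()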
Answer 0 (score: 0)
I would suggest skipping the repartition and sort steps and jumping straight to a distributed compressed merge sort (I just invented the name for the algorithm, just like the algorithm itself).

Here is the part of the algorithm that should be used as the reduce operation:
type Gap = (Int, Int)
def mergeIntervals(as: List[Gap], bs: List[Gap]): List[Gap] = {
require(!as.isEmpty, "as must be non-empty")
require(!bs.isEmpty, "bs must be non-empty")
@annotation.tailrec
def mergeRec(
gaps: List[Gap],
gapStart: Int,
gapEndAccum: Int,
as: List[Gap],
bs: List[Gap]
): List[Gap] = {
as match {
case Nil => {
bs match {
case Nil => (gapStart, gapEndAccum) :: gaps
case notEmpty => mergeRec(gaps, gapStart, gapEndAccum, bs, Nil)
}
}
case (a0, a1) :: at => {
if (a0 <= gapEndAccum) {
mergeRec(gaps, gapStart, gapEndAccum max a1, at, bs)
} else {
bs match {
case Nil => mergeRec((gapStart, gapEndAccum) :: gaps, a0, gapEndAccum max a1, at, bs)
case (b0, b1) :: bt => if (b0 <= gapEndAccum) {
mergeRec(gaps, gapStart, gapEndAccum max b1, as, bt)
} else {
if (a0 < b0) {
mergeRec((gapStart, gapEndAccum) :: gaps, a0, a1, at, bs)
} else {
mergeRec((gapStart, gapEndAccum) :: gaps, b0, b1, as, bt)
}
}
}
}
}
}
}
val (a0, a1) :: at = as
val (b0, b1) :: bt = bs
val reverseRes =
if (a0 < b0)
mergeRec(Nil, a0, a1, at, bs)
else
mergeRec(Nil, b0, b1, as, bt)
reverseRes.reverse
}
Here is how it works:
println(mergeIntervals(
List((0, 3), (4, 7), (9, 11), (15, 16), (18, 22)),
List((1, 2), (4, 5), (6, 10), (12, 13), (15, 17))
))
// Outputs:
// List((0,3), (4,11), (12,13), (15,17), (18,22))
Now, if you combine it with Spark's parallel reduce,
val mergedIntervals = df.
as[(String, Int, Int)].
rdd.
map{case (t, s, e) => (t, List((s, e)))}. // Convert start end to list with one interval
reduceByKey(mergeIntervals). // perform parallel compressed merge-sort
flatMap{ case (k, vs) => vs.map(v => (k, v._1, v._2))}.// explode resulting lists of merged intervals
toDF("typ", "start", "end") // convert back to DF
mergedIntervals.show()
you get something like a parallel merge sort that works directly on the compressed representation of the integer sequences (hence the name).

The result:
+-----+-----+---+
| typ|start|end|
+-----+-----+---+
| Two| 10| 25|
| Two| 40| 45|
| One| 0| 8|
| One| 10| 15|
|Three| 30| 35|
+-----+-----+---+
Discussion
The mergeIntervals method implements a commutative, associative operation for merging lists of non-overlapping intervals that are already sorted in increasing order. All overlapping intervals are merged, and the result is again stored in increasing order. This process can be repeated in the reduce step until all interval sequences have been merged.
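As a quick, informal check of the commutativity and associativity claim (plain assertions on small hand-made interval lists, not a proof):

val xs = List((0, 3), (9, 11))
val ys = List((1, 5), (15, 16))
val zs = List((4, 8), (20, 22))

// both argument orders must give the same merged result ...
assert(mergeIntervals(xs, ys) == mergeIntervals(ys, xs))
// ... and so must both ways of bracketing a three-way merge
assert(
  mergeIntervals(mergeIntervals(xs, ys), zs) ==
  mergeIntervals(xs, mergeIntervals(ys, zs))
)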
An interesting property of this algorithm is that it maximally compresses every intermediate reduce result. So if you have many overlapping intervals, this algorithm might actually be faster than other algorithms that are based on sorting the input intervals.
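For instance (a small illustration of that compression, reusing mergeIntervals from above): a thousand heavily overlapping single-interval lists collapse into one pair, so the intermediate values stay tiny:

val heavy = (0 until 1000).map(i => List((i, i + 10))).toList
println(heavy.reduce(mergeIntervals))  // List((0,1009))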
However, if you have many intervals with very few overlaps, this method may run out of memory and not work at all, and a different algorithm must be used: one that first sorts the intervals and then does some kind of scan, merging adjacent intervals locally (a sketch of such a pass follows). So whether this works depends on the use case.
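A rough sketch of such a sort-and-scan pass using DataFrame window functions (the usual "gaps and islands" trick; this is not the reduce-based algorithm above, and it assumes import sparkSession.implicits._ and the df from the question):

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// Sort within each typ, carry a running maximum of all previous ends,
// and start a new "island" whenever the current start jumps past it.
val byTyp  = Window.partitionBy($"typ").orderBy($"start")
val before = byTyp.rowsBetween(Window.unboundedPreceding, -1)

val mergedBySort = df
  .withColumn("prevMaxEnd", max($"end").over(before))
  .withColumn("newIsland", when($"start" > $"prevMaxEnd", 1).otherwise(0))
  .withColumn("island", sum($"newIsland").over(byTyp))
  .groupBy($"typ", $"island")
  .agg(min($"start").as("start"), max($"end").as("end"))
  .drop("island")

mergedBySort.show()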
Full code
import sparkSession.implicits._  // needed for toDF / as[...], assuming the question's sparkSession

val df = Seq(
("One", 0, 5),
("One", 10, 15),
("One", 5, 8),
("Two", 10, 25),
("Two", 40, 45),
("Three", 30, 35)
).toDF("typ", "start", "end")
type Gap = (Int, Int)
/** The `merge`-step of a variant of merge-sort
* that works directly on compressed sequences of integers,
* where instead of individual integers, the sequence is
* represented by sorted, non-overlapping ranges of integers.
*/
def mergeIntervals(as: List[Gap], bs: List[Gap]): List[Gap] = {
require(!as.isEmpty, "as must be non-empty")
require(!bs.isEmpty, "bs must be non-empty")
// assuming that `as` and `bs` both are either lists with a single
// interval, or sorted lists that arise as output of
// this method, recursively merges them into a single list of
// gaps, merging all overlapping gaps.
@annotation.tailrec
def mergeRec(
gaps: List[Gap],
gapStart: Int,
gapEndAccum: Int,
as: List[Gap],
bs: List[Gap]
): List[Gap] = {
as match {
case Nil => {
bs match {
case Nil => (gapStart, gapEndAccum) :: gaps
case notEmpty => mergeRec(gaps, gapStart, gapEndAccum, bs, Nil)
}
}
case (a0, a1) :: at => {
if (a0 <= gapEndAccum) {
mergeRec(gaps, gapStart, gapEndAccum max a1, at, bs)
} else {
bs match {
case Nil => mergeRec((gapStart, gapEndAccum) :: gaps, a0, gapEndAccum max a1, at, bs)
case (b0, b1) :: bt => if (b0 <= gapEndAccum) {
mergeRec(gaps, gapStart, gapEndAccum max b1, as, bt)
} else {
if (a0 < b0) {
mergeRec((gapStart, gapEndAccum) :: gaps, a0, a1, at, bs)
} else {
mergeRec((gapStart, gapEndAccum) :: gaps, b0, b1, as, bt)
}
}
}
}
}
}
}
val (a0, a1) :: at = as
val (b0, b1) :: bt = bs
val reverseRes =
if (a0 < b0)
mergeRec(Nil, a0, a1, at, bs)
else
mergeRec(Nil, b0, b1, as, bt)
reverseRes.reverse
}
val mergedIntervals = df.
as[(String, Int, Int)].
rdd.
map{case (t, s, e) => (t, List((s, e)))}. // Convert start end to list with one interval
reduceByKey(mergeIntervals). // perform parallel compressed merge-sort
flatMap{ case (k, vs) => vs.map(v => (k, v._1, v._2))}.// explode resulting lists of merged intervals
toDF("typ", "start", "end") // convert back to DF
mergedIntervals.show()
Testing
The implementation of mergeIntervals has been tested a little. If you want to actually incorporate it into your code base, here is at least a sketch of a repeated randomized test:
def randomIntervalSequence(): List[Gap] = {
def recHelper(acc: List[Gap], open: Option[Int], currIdx: Int): List[Gap] = {
if (math.random > 0.999) acc.reverse
else {
if (math.random > 0.90) {
if (open.isEmpty) {
recHelper(acc, Some(currIdx), currIdx + 1)
} else {
recHelper((open.get, currIdx) :: acc, None, currIdx + 1)
}
} else {
recHelper(acc, open, currIdx + 1)
}
}
}
recHelper(Nil, None, 0)
}
def intervalsToInts(is: List[Gap]): List[Int] = is.flatMap{ case (a, b) => a to b }
var numNonTrivialTests = 0
while(numNonTrivialTests < 1000) {
val as = randomIntervalSequence()
val bs = randomIntervalSequence()
if (!as.isEmpty && !bs.isEmpty) {
numNonTrivialTests += 1
val merged = mergeIntervals(as, bs)
assert((intervalsToInts(as).toSet ++ intervalsToInts(bs)) == intervalsToInts(merged).toSet)
}
}
You would obviously have to replace the raw assert with something more civilized, depending on your framework.
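For example, with ScalaCheck (just one possible choice), the loop above could become a property roughly like the following sketch, reusing Gap, mergeIntervals and intervalsToInts from above; genIntervals is a hypothetical generator for sorted, non-overlapping interval lists:

import org.scalacheck.Gen
import org.scalacheck.Prop.forAll

// Generator for a non-empty, sorted list of non-overlapping intervals:
// accumulate random positive gaps and widths from left to right.
val genIntervals: Gen[List[Gap]] =
  Gen.nonEmptyListOf(Gen.zip(Gen.choose(1, 5), Gen.choose(0, 5))).map { steps =>
    steps.foldLeft((List.empty[Gap], 0)) { case ((acc, pos), (gap, width)) =>
      val start = pos + gap
      ((start, start + width) :: acc, start + width + 1)
    }._1.reverse
  }

// The merged result must cover exactly the union of the points of both inputs.
val mergeCoversUnion = forAll(genIntervals, genIntervals) { (as: List[Gap], bs: List[Gap]) =>
  val merged = mergeIntervals(as, bs)
  (intervalsToInts(as).toSet ++ intervalsToInts(bs)) == intervalsToInts(merged).toSet
}

mergeCoversUnion.check()  // or wire it into your test framework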