I would like to write a function that handles data skew when joining two Spark datasets. The solution for DataFrames is straightforward:
def saltedJoin(left: DataFrame, right: DataFrame, e: Column, kind: String = "inner", replicas: Int): DataFrame = {
  // Replicate every left row once per salt value in [0, replicas).
  val saltedLeft = left.
    withColumn("__temporarily__", typedLit((0 until replicas).toArray)).
    withColumn("__skew_left__", explode($"__temporarily__")).
    drop($"__temporarily__").
    repartition($"__skew_left__")
  // Assign every right row a single random salt value in [0, replicas).
  val saltedRight = right.
    withColumn("__temporarily__", rand).
    withColumn("__skew_right__", ($"__temporarily__" * replicas).cast("bigint")).
    drop("__temporarily__").
    repartition($"__skew_right__")
  // Join on the salt in addition to the caller's condition, then drop the salt columns.
  saltedLeft.
    join(saltedRight, $"__skew_left__" === $"__skew_right__" && e, kind).
    drop($"__skew_left__").
    drop($"__skew_right__")
}
You would use the function like this:
val joined = saltedJoin(df alias "l", df alias "r", $"l.x" === $"r.x", replicas = 5)
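The reason this preserves the join result is that every right-hand salt value in [0, replicas) matches exactly one of the left-hand replicas, so each matching pair of rows still appears exactly once. A minimal plain-Scala sketch of that pairing (no Spark; the row values here are just illustrative):

```scala
import scala.util.Random

// Plain-Scala sketch of the salting idea: every left row is replicated
// once per salt value, every right row gets one random salt, so each
// right row still matches each matching left row exactly once.
val replicas = 5
val left  = Seq("a", "b")            // left rows
val right = Seq("a", "a", "a", "b")  // heavily skewed right rows

val random = new Random(42)
val saltedLeft  = left.flatMap(a => (0 until replicas).map(s => (a, s)))
val saltedRight = right.map(b => (b, random.nextInt(replicas)))

// Join on both the value and the salt, then drop the salt.
val joined = for {
  (a, sl) <- saltedLeft
  (b, sr) <- saltedRight
  if a == b && sl == sr
} yield (a, b)
```

The salt spreads the skewed key across `replicas` partitions instead of concentrating it in one.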
However, I cannot figure out how to write an equivalent join function for Dataset instances. Here is what I have written so far:
def saltedJoinWith[A: Encoder : TypeTag, B: Encoder : TypeTag](left: Dataset[A],
                                                               right: Dataset[B],
                                                               e: Column,
                                                               kind: String = "inner",
                                                               replicas: Int): Dataset[(A, B)] = {
  val spark = left.sparkSession
  val random = new Random()
  import spark.implicits._
  // Replicate every left element once per salt value; salt every right element randomly.
  val saltedLeft: Dataset[(A, Int)] = left flatMap (a => 0 until replicas map ((a, _)))
  val saltedRight: Dataset[(B, Int)] = right map ((_, random.nextInt(replicas)))
  saltedLeft.
    joinWith(saltedRight, saltedLeft("_2") === saltedRight("_2") && e, kind).
    map(x => (x._1._1, x._2._1))
}
This is obviously not a correct solution, because the join condition e does not refer to the columns defined in saltedLeft and saltedRight; it refers to the columns inside saltedLeft._1 and saltedRight._1. So, for example, val j = saltedJoinWith(ds alias "l", ds alias "r", $"l.x" === $"r.x", replicas = 5) will fail at runtime with the following exception:
org.apache.spark.sql.AnalysisException: cannot resolve '`l.x`' given input columns: [_1, _2];;
I am using Apache Spark 2.2.
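One direction I have considered is to sidestep the Column resolution problem entirely by taking typed key-extractor functions instead of a Column, so the condition never has to name columns at all. This is only a sketch with plain Scala collections standing in for Datasets (the names saltedJoinWithKeys, leftKey, and rightKey are mine, not Spark API), and it only covers equi-joins:

```scala
// Hypothetical sketch: a salted inner equi-join driven by typed key
// functions instead of a Column, so nothing has to resolve against the
// (A, Int) tuple schema. Plain Scala collections stand in for Datasets.
def saltedJoinWithKeys[A, B, K](left: Seq[A],
                                right: Seq[B],
                                leftKey: A => K,
                                rightKey: B => K,
                                replicas: Int): Seq[(A, B)] = {
  val random = new scala.util.Random()
  // Replicate each left element once per salt; salt each right element randomly.
  val saltedLeft  = left.flatMap(a => (0 until replicas).map(s => ((leftKey(a), s), a)))
  val saltedRight = right.map(b => ((rightKey(b), random.nextInt(replicas)), b))
  // Group the salted left side by (key, salt) and look up each right element.
  val leftIndex = saltedLeft.groupBy(_._1)
  for {
    ((k, s), b) <- saltedRight
    (_, a)      <- leftIndex.getOrElse((k, s), Nil)
  } yield (a, b)
}
```

I have not worked out whether the same shape can be expressed on real Datasets via flatMap/map/joinWith without running into the same resolution issue, which is essentially my question.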