如何避免在Spark中广播大型查找表

时间:2016-07-13 17:51:25

标签: scala apache-spark apache-spark-sql spark-dataframe

你能帮我避免广播大型查询表吗?我有一张测量表:

Measurement     Value
x1              5.1
x2              8.9
x1              9.1
x3              4.4
x2              2.1
...

一对配对列表:

P1              P2
x1              x2
x2              x3
...

任务是获取每对元素的所有值并将它们放入魔术函数中。这就是我通过用测量值广播大表来解决它的方法。

case class Measurement(measurement: String, value: Double)
case class Candidate(c1: String, c2: String)

val measurements = Seq(Measurement("x1", 5.1), Measurement("x2", 8.9), 
  Measurement("x1", 9.1), Measurement("x3", 4.4))
val candidates = Seq(Candidate("x1", "x2"), Candidate("x2", "x3"))

// create data frames
val dfm = sqc.createDataFrame(measurements)
val dfc = sqc.createDataFrame(candidates)

// broadcast lookup table
val lookup = sc.broadcast(dfm.rdd.map(r => (r(0), r(1))).collect())

// udf: run magic test with every candidate
val magic: ((String, String) => Double) = (c1: String, c2: String) => {
  val lt = lookup.value

  val c1v = lt.filter(_._1 == c1).map(_._2).map(_.asInstanceOf[Double])
  val c2v = lt.filter(_._1 == c2).map(_._2).map(_.asInstanceOf[Double])

  new Foo().magic(c1v, c2v)
}

val sq1 = udf(magic)
val dfks = dfc.withColumn("magic", sq1(col("c1"), col("c2")))

你可以猜到我对这个解决方案并不满意。对于每一对我过滤查找表两次,这不是快速也不优雅。我正在使用Spark 1.6.1。

2 个答案:

答案 0 :(得分:2)

另一种方法是使用RDD和join。但不确定在性能方面有什么好处。

case class Measurement(measurement: String, value: Double)
case class Candidate(c1: String, c2: String)

val measurements = Seq(Measurement("x1", 5.1), Measurement("x2", 8.9), 
Measurement("x1", 9.1), Measurement("x3", 4.4))
val candidates = Seq(Candidate("x1", "x2"), Candidate("x2", "x3"))

val rdm = sc.parallelize(measurements).map(r => (r.measurement, r.value)).groupByKey().cache()
val rdc = sc.parallelize(candidates).map(r => (r.c1, r.c2)).cache()

val firstColJoin = rdc.join(rdm).values
val secondColJoin = firstColJoin.join(rdm).values

secondColJoin.map { case (c1v, c2v) => new Foo().magic(c1v, c2v) }

答案 1 :(得分:0)

感谢您的所有评论。我阅读了评论,做了一些研究并研究了zero323个帖子。

我目前的解决方案是使用两个joins和一个UserDefinedAggregateFunction

object GroupValues extends UserDefinedAggregateFunction {
  def inputSchema = new StructType().add("x", DoubleType)
  def bufferSchema = new StructType().add("buff", ArrayType(DoubleType))
  def dataType = ArrayType(DoubleType)
  def deterministic = true

  def initialize(buffer: MutableAggregationBuffer) = {
    buffer.update(0, ArrayBuffer.empty[Double])
  }

  def update(buffer: MutableAggregationBuffer, input: Row) = {
    if (!input.isNullAt(0))
      buffer.update(0, buffer.getSeq[Double](0) :+ input.getDouble(0))
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
    buffer1.update(0, buffer1.getSeq[Double](0) ++ buffer2.getSeq[Double](0))
  }

  def evaluate(buffer: Row) = buffer.getSeq[Double](0)
}

// join data for candidate one
val j1 = dfc.join(dfm, dfc("c1") === dfm("measurement"))

// aggregate all c1 values to an array
val j1v = j1.groupBy(col("c1"), col("c2")).agg(GroupValues(col("value"))
  .alias("c1-values"))

// join data for candidate two
val j2 = j1v.join(dfm, j1v("c2") === dfm("measurement"))

// aggregate all c2 values to an array
val j2v = j2.groupBy(col("c1"), col("c2"), col("c1-values"))
  .agg(GroupValues(col("value")).alias("c2-values"))

下一步是使用collect_list代替UserDefinedAggregateFunction