Spark:在某些发行版上使用Long键时,flatMap / reduceByKey似乎相当慢

时间:2017-04-03 11:26:04

标签: scala apache-spark

我正在使用Spark来处理一些语料库,我需要计算每个2克的出现次数。我开始计算元组(wordID1, wordID2)并且它工作得很好,除了由于大量的小元组对象导致的大内存使用和gc开销。然后我尝试将一对Int打包到Long中,并且gc开销确实大大减少,但运行时间也增加了几倍。

我用不同分布的随机数据进行了一些小实验。似乎性能问题只发生在指数分布式数据上。

// lines of word IDs
val data = (1 to 5000).par.map({ _ =>
  (1 to 1000) map { _ => (-1000 * Math.log(Random.nextDouble)).toInt }
}).seq

// count Tuples, fast
sc parallelize(data) flatMap { line =>
  val first = line.iterator
  val second = line.iterator.drop(1)
  for (pair <- first zip(second))
    yield (pair, 1L)
} reduceByKey { _ + _ } count()

// count Long, slow
sc parallelize(data) flatMap { line =>
  val first = line.iterator
  val second = line.iterator.drop(1)
  for ((a, b) <- first zip(second))
    yield ((a.toLong << 32) | b, 1L)
} reduceByKey { _ + _ } count()

这项工作分为两个阶段,flatMap()count()。在计算Tuple2 s时,flatMap()需要大约6s而count()需要大约2s,而在计算Long s时,flatMap()需要18s和count()需要10秒。

这对我来说没有意义,因为Long应该比Tuple2施加更少的开销。 spark是否对Long键有一些特殊性,对某些特定的发行版来说,它们的表现更差?

1 个答案:

答案 0 :(得分:1)

感谢@ SarveshKumarSingh的提示,我终于解决了这个问题。引发问题的不是Spark的Long专业化,而是Java和Spark没有正确解决它。

Java's hashCode() for Long非常简单并且强烈依赖于值的两半,而Spark的默认HashPartitioner只是按照hashCode()值对分区号进行模数分区。这些使得Spark的默认分区对Long键的分布非常敏感,尤其是当分区数量相对较小时。在我的情况下,情况会恶化,因为Long键是通过连接Int对来构建的。

解决方案非常简单,因为我们只需要以某种方式“改变”键,这使得具有相似频率的键均匀分布。

最简单的方法是使用some perfect hash function将每个键映射到另一个唯一值,并在需要原始密钥时将其转换回来。此方法仅涉及较小的代码更改,但可能无法很好地执行。我使用以下映射实现了与逐元组方法类似的性能。

val newKey = oldKey * 6364136223846793005L + 1442695040888963407L
val oldKey = (newKey - 1442695040888963407L) * -4568919932995229531L

更有效的方法是替换默认的HashPartitioner。我在flatMapreduceByKey之间使用了以下分区程序,并在实际数据上实现了两倍的性能提升。

val prevRDD = // ... flatMap ...
val nParts = prevRDD.partitioner match {
  case Some(p) => p.numPartitions
  case None => prevRDD.partitions.size
}

prevRDD partitionBy (new Partitioner {
  override def getPartition(key: Any): Int = {
    val rawMod = LongHash(key.asInstanceOf[Long]) % numPartitions
    rawMod + (if (rawMod < 0) numPartitions else 0)
  }
  override def numPartitions: Int = nParts
}) reduceByKey { _ + _ }

def LongHash(v: Long) = { // the 64bit mix function from Murmurhash3
  var k = v
  k ^= k >> 33
  k *= 0xff51afd7ed558ccdL
  k ^= k >> 33
  k *= 0xc4ceb9fe1a85ec53L
  k ^= k >> 33
  k.toInt
}