如何检查给定键的所有记录是否已在同一分区中?

时间:2016-12-29 10:21:55

标签: apache-spark

我希望尽可能避免按密钥对数据集进行重新分区,并知道给定密钥的所有记录是否都在同一个分区中。

Spark中是否有内置函数可以给我答案?

2 个答案:

答案 0 :(得分:1)

不是内置的,但是如果你假定特定的分区,那么很容易实现你自己的功能:

import org.apache.spark.rdd.RDD
import org.apache.spark.Partitioner
import scala.reflect.ClassTag

def checkDistribution[K : ClassTag, V : ClassTag](
   rdd: RDD[(K, V)], partitioner: Partitioner) = 
  // If partitioner is set we compare partitioners 
  rdd.partitioner.map(_ == partitioner).getOrElse {
    // Otherwise check if correct number of partitions 
    rdd.partitions.size ==  partitioner.numPartitions &&
    //  And check if distribution matches partitioner
    rdd.keys.mapPartitionsWithIndex((i, iter) => 
      Iterator(iter.forall(x => partitioner.getPartition(x) == i))
    ).fold(true)(_ && _)
  }

一些测试:

import org.apache.spark.HashPartitioner

val rdd = sc.range(0, 20, 5).map((_, None))
  • 未分区,分发无效:

    checkDistribution(rdd, new HashPartitioner(10))
    
    Boolean = false
    
  • 已分区,无效的分区程序:

    checkDistribution(
      rdd.partitionBy(new HashPartitioner(5)),
      new HashPartitioner(10)
    )
    
    Boolean = false
    
  • 已分区的有效分区程序:

    checkDistribution(
      rdd.partitionBy(new HashPartitioner(10)),
      new HashPartitioner(10)
    )
    
    Boolean = true
    
  • 未分区,有效分发:

    checkDistribution(
      rdd.partitionBy(new HashPartitioner(10)).map(identity),
      new HashPartitioner(10)
    )
    
    Boolean = true
    

如果不考虑特定的分区程序,我想到的唯一选择是需要随机播放,因此它不太可能是一种改进。

def checkDistribution[K : ClassTag, V : ClassTag](rdd: RDD[(K, V)]) =
   rdd.keys.mapPartitionsWithIndex((i, iter) => iter.map((_, i)))
     .combineByKey(
       x => Seq(x), 
       (x: Seq[Int], y: Int) => x, 
       (x: Seq[Int], y: Seq[Int]) => x ++ y)  // Should be more or less OK
     .values
     .mapPartitions(iter => Iterator(iter.forall(_.size == 1)))
     .fold(true)(_ && _)

一种可能的改进是您可以使用相同的逻辑为数据自动定义Partitioner。如果您在collectAsMap之前values并检查所有Seqs的大小为1,那么您将拥有一个有效的分区程序,以确保无网络流量。

答案 1 :(得分:1)

不是您要求的100%,但您可以使用spark_partition_id进行检查。基本上做:

withColumn("pid", spark_partition_id())

然后执行:

df.groupby(what you want to check).agg(max($"pid").as("pidmax"),min($"pid").as("pidmin")).filter($"pidmax"===$"pidmin").count()

计数将为您提供未分区的元素数量。 请注意,这是一个相对较低的成本,只是一个简单的聚合。

我不相信有一种通用方式,因为如果我们从通用来源(例如文件)中读取内容,我们就不一定知道源最初是如何分区的。

如果有类似"获得当前分区和#34;那将是很好的。这将得到明确的分区器(例如,如果我们有一个明确的重新分区命令或读取使用PartitionBy写的镶木板中的东西)作为近似值。