Spark Override Accumulator

Time: 2015-06-25 17:18:22

Tags: scala apache-spark

I can't seem to find any examples of how to override a Spark Accumulator. I have data in key/value format, where the key is the column index. The function below filters out anything that is not a number. My goal is to keep track of how many entries get emptied out of each column.

I have the following filter:

    // accumNum is an accumulator declared elsewhere; it counts every record seen
    def isAllDigits(x: String) = x matches """^\d+\.?\d*$"""

    def filterNum(x: (Int, String)): Boolean = {
      accumNum.add(1)
      isAllDigits(x._2)
    }

    val numFilterRDD = numRDD.filter(filterNum)

Right now the solution is heavier than it needs to be, because I also have to run the following before the filter:

    val originalCountNum = numRDD.map(x => (x._1, 1)).reduceByKey(_ + _).collect()

and then compare the two at the end. If the accumulator could track column index + emptied count instead, that would eliminate the extra pass over the data just to get the original counts.
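
For context, the end-of-job comparison could look something like the following sketch (filteredCountNum and emptiedPerColumn are illustrative names and the diff logic is assumed, not from the original post):

    // Illustrative sketch of the two-pass comparison described above
    // (assumed code, not from the original post).
    val filteredCountNum = numFilterRDD
      .map(x => (x._1, 1))
      .reduceByKey(_ + _)
      .collectAsMap()

    // Emptied entries per column = original count minus surviving count.
    val emptiedPerColumn = originalCountNum.map { case (col, orig) =>
      (col, orig - filteredCountNum.getOrElse(col, 0))
    }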

1 Answer:

Answer 0 (score: 4)

You have to use a custom AccumulatorParam. For example, you can use a map like this:

    import org.apache.spark.AccumulatorParam

    object CountPairsParam extends AccumulatorParam[Map[Int, Int]] {

      // The "zero" element for merging: an empty per-column count map.
      def zero(initialValue: Map[Int, Int]): Map[Int, Int] = {
        Map.empty[Int, Int]
      }

      // Merge two count maps by summing the counts for every key.
      def addInPlace(m1: Map[Int, Int], m2: Map[Int, Int]): Map[Int, Int] = {
        val keys = m1.keys ++ m2.keys
        keys.map((k: Int) => k -> (m1.getOrElse(k, 0) + m2.getOrElse(k, 0))).toMap
      }
    }

    val rdd = sc.parallelize(List((1, -1), (2, 1), (3, 0), (3, -1), (2, 0)))
    val accum = sc.accumulator(Map.empty[Int, Int])(CountPairsParam)

Inside your filter you then do something like this:

    val allDigits = isAllDigits(x._2)
    if (!allDigits) {
      // record one emptied (non-numeric) entry for this column
      accum += Map(x._1 -> 1)
    }
    allDigits
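
Putting the pieces together, here is a minimal end-to-end sketch (the filterNum wrapper and the sample data are assumptions for illustration, not from the answer):

    // End-to-end usage sketch (Spark 1.x API), assuming the question's
    // (Int, String) records keyed by column index.
    val accum = sc.accumulator(Map.empty[Int, Int])(CountPairsParam)

    def isAllDigits(x: String) = x matches """^\d+\.?\d*$"""

    def filterNum(x: (Int, String)): Boolean = {
      val allDigits = isAllDigits(x._2)
      if (!allDigits) accum += Map(x._1 -> 1) // one emptied entry for this column
      allDigits
    }

    val numRDD = sc.parallelize(Seq((0, "12.5"), (0, "abc"), (1, "7"), (1, "")))
    val numFilterRDD = numRDD.filter(filterNum)
    numFilterRDD.count()  // an action must run before accum.value is populated
    println(accum.value)  // e.g. Map(0 -> 1, 1 -> 1)

Two caveats: accum.value is only readable on the driver, and because the updates happen inside a transformation (filter) rather than an action, Spark may re-apply them if a task is retried, so the counts are best treated as approximate.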