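I can't seem to find any examples of how to override a Spark Accumulator. I have data in key/value format, where the key is the column index. My function below filters out anything that is not a number. My goal is to track how many values are blanked out per column.
I have the following filter: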
val numFilterRDD = numRDD.filter(filterNum)
def isAllDigits(x: String) = x matches """^\d{1,}\.*\d*$"""
def filterNum(x: (Int, String)) : Boolean = {
  accumNum.add(1)
  if (isAllDigits(x._2)) true
  else false
}
My current solution is overkill, because I have to run the following before the filter:
val originalCountNum = numRDD.map(x => (x._1, 1)).reduceByKey(_ + _).collect()
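and then compare the two at the end. If an accumulator could track the column index together with the blank count, it would eliminate the extra pass for the original counts.
Answer 0 (score: 4)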
You have to use a custom AccumulatorParam. For example, you can use a Map like this:
object CountPairsParam extends AccumulatorParam[Map[Int, Int]] {
  def zero(initialValue: Map[Int, Int]): Map[Int, Int] = {
    Map.empty[Int, Int]
  }
  def addInPlace(m1: Map[Int, Int], m2: Map[Int, Int]): Map[Int, Int] = {
    val keys = m1.keys ++ m2.keys
    keys.map((k: Int) => (k -> (m1.getOrElse(k, 0) + m2.getOrElse(k, 0)))).toMap
  }
}
val rdd = sc.parallelize(List((1, -1), (2, 1), (3, 0), (3, -1), (2, 0)))
val accum = sc.accumulator(Map.empty[Int, Int])(CountPairsParam)
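To see what addInPlace computes when partial results are combined, the merge can be tried on plain maps without a cluster. The sketch below copies the same per-key sum into a standalone function; `MergeCountsDemo` and `mergeCounts` are hypothetical names used only for illustration, not part of Spark.

```scala
// Plain-Scala sketch of the per-key merge; no Spark dependency needed.
object MergeCountsDemo {
  // Same idea as addInPlace: for every key seen in either map,
  // add the two counts, treating a missing key as 0.
  def mergeCounts(m1: Map[Int, Int], m2: Map[Int, Int]): Map[Int, Int] =
    (m1.keySet ++ m2.keySet).map { k =>
      k -> (m1.getOrElse(k, 0) + m2.getOrElse(k, 0))
    }.toMap

  def main(args: Array[String]): Unit = {
    // Pretend these are the partial counts from two partitions.
    val partitionA = Map(1 -> 2, 3 -> 1)
    val partitionB = Map(3 -> 4, 5 -> 1)
    println(mergeCounts(partitionA, partitionB))
  }
}
```

Column 3 appears in both inputs, so its counts are summed, while columns 1 and 5 are carried over unchanged.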
Inside the filter you can do something like this:
val allDigits = isAllDigits(x._2)
if (allDigits) {
  accum += Map(x._1 -> 1)
}
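On Spark 2.0 and later, AccumulatorParam is deprecated in favor of AccumulatorV2, so the same per-column count could be sketched as below. This is only an illustration under assumptions: `ColumnCountAccumulator` is a hypothetical class name, and it counts the values that fail the check (the ones the question wants to track), which is the opposite of the condition in the snippet above.

```scala
import org.apache.spark.util.AccumulatorV2
import scala.collection.mutable

// Hypothetical accumulator: input is a column index, output is a count per column.
class ColumnCountAccumulator extends AccumulatorV2[Int, Map[Int, Long]] {
  private val counts = mutable.Map.empty[Int, Long]

  override def isZero: Boolean = counts.isEmpty

  override def copy(): ColumnCountAccumulator = {
    val acc = new ColumnCountAccumulator
    acc.counts ++= counts
    acc
  }

  override def reset(): Unit = counts.clear()

  // Record one occurrence for the given column index.
  override def add(column: Int): Unit =
    counts(column) = counts.getOrElse(column, 0L) + 1L

  // Fold another accumulator's per-column counts into this one.
  override def merge(other: AccumulatorV2[Int, Map[Int, Long]]): Unit =
    other.value.foreach { case (column, n) =>
      counts(column) = counts.getOrElse(column, 0L) + n
    }

  override def value: Map[Int, Long] = counts.toMap
}

// Usage sketch, assuming an existing SparkContext `sc` plus the question's
// `numRDD` and `isAllDigits`:
//   val dropped = new ColumnCountAccumulator
//   sc.register(dropped, "droppedPerColumn")
//   val numFilterRDD = numRDD.filter { case (column, value) =>
//     val keep = isAllDigits(value)
//     if (!keep) dropped.add(column)
//     keep
//   }
//   numFilterRDD.count()   // dropped.value is only populated after an action runs
```

Because the update happens inside a transformation, Spark may apply it more than once if a task or stage is re-executed, so the counts are best treated as diagnostics rather than exact totals.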