Spark RDD:如何最有效地计算统计数据?

时间:2016-10-11 15:43:58

标签: apache-spark pyspark distributed-computing rdd apache-spark-mllib

假设存在类似于以下内容的元组RDD:

(key1, 1)
(key3, 9)
(key2, 3)
(key1, 4)
(key1, 5)
(key3, 2)
(key2, 7)
...

计算与每个密钥对应的统计信息的最有效(且理想情况下,分布式)方式是什么? (目前,我正在考虑计算标准偏差/方差。)据我了解,我的选择相当于:

  1. 使用colStats function in MLLib如果认为有必要进行其他统计计算,此方法的优点是可以轻松适应以后使用其他mllib.stat函数。但是,它在Vector的RDD上运行,其中包含每列的数据,因此据我所知,这种方法需要在单个节点上收集每个键的完整值集,这看似非 - 大型数据集的理想选择。 Spark Vector是否总是暗示Vector中的数据是本地驻留在单个节点上的?
  2. 执行groupByKey,然后stats可能是重复,as a result of the groupByKey operation
  3. 执行aggregateByKey,初始化新的StatCounter,并使用StatCounter::merge作为序列和合并器功能:这是方法recommended by this StackOverflow answer,并避免选项2中的groupByKey。但是,我还没能在PySpark中找到StatCounter的好文档。
  4. 我喜欢选项1,因为它使代码更具可扩展性,因为它可以轻松地使用具有类似契约的其他MLLib函数来容纳更复杂的计算,但是如果Vector输入本身需要在本地收集数据集,然后它限制了代码可以有效操作的数据大小。在另外两个之间,选项3 看起来更有效,因为它避免了groupByKey,但我希望确认是这种情况。

    我还有其他选择吗? (我目前正在使用Python + PySpark,但如果存在语言差异,我也可以使用Java / Scala中的解决方案。)

1 个答案:

答案 0 :(得分:3)

您可以尝试reduceByKey。如果我们只想计算min()

,这非常简单
rdd.reduceByKey(lambda x,y: min(x,y)).collect()
#Out[84]: [('key3', 2.0), ('key2', 3.0), ('key1', 1.0)]

要计算mean,您首先需要创建(value, 1)元组,用于计算sum中的countreduceByKey操作。最后,我们将它们相互分开以得到mean

meanRDD = (rdd
           .mapValues(lambda x: (x, 1))
           .reduceByKey(lambda x, y: (x[0]+y[0], x[1]+y[1]))
           .mapValues(lambda x: x[0]/x[1]))

meanRDD.collect()
#Out[85]: [('key3', 5.5), ('key2', 5.0), ('key1', 3.3333333333333335)]

对于variance,您可以使用公式(sumOfSquares/count) - (sum/count)^2, 我们按以下方式翻译:

varRDD = (rdd
          .mapValues(lambda x: (1, x, x*x))
          .reduceByKey(lambda x,y: (x[0]+y[0], x[1]+y[1], x[2]+y[2]))
          .mapValues(lambda x: (x[2]/x[0] - (x[1]/x[0])**2)))

varRDD.collect()
#Out[106]: [('key3', 12.25), ('key2', 4.0), ('key1', 2.8888888888888875)]

我在虚拟数据中使用了double类型的值代替int来准确地说明计算平均值和方差:

rdd = sc.parallelize([("key1", 1.0),
                      ("key3", 9.0),
                      ("key2", 3.0),
                      ("key1", 4.0),
                      ("key1", 5.0),
                      ("key3", 2.0),
                      ("key2", 7.0)])