使用Statistic.stat时如何避免收集

时间:2015-12-15 13:52:16

标签: java scala apache-spark

当我计算数据的方差时,我必须先收集,还有其他方法吗?

我的数据格式:

1   2   3 
1   4   5
4   5   6
4   7   8
7   8   9
10  11  12
10  13  14
10  1   2
1   100 100
10  11  2
10  11  2
1   2   5
4   7   6   

代码:

val conf = new SparkConf().setAppName("hh")
conf.setMaster("local[3]")
val sc = new SparkContext(conf)
val data = sc.textFile("/home/hadoop4/Desktop/i.txt")
  .map(_.split("\t")).map(f => f.map(f => f.toDouble))
  .map(f => ("k"+f(0),f(1)))
//data:RDD[(String,Double)]
val dataArr = data.map(f=>(f._1,ArrayBuffer(f._2)))
//dataArr  RDD[(String,ArrayBuffer[Double])]

dataArr.collect().foreach(println(_))
//output
(k1.0,ArrayBuffer(2.0))
(k1.0,ArrayBuffer(4.0))
(k4.0,ArrayBuffer(5.0))
(k4.0,ArrayBuffer(7.0))
(k7.0,ArrayBuffer(8.0))
(k10.0,ArrayBuffer(11.0))
(k10.0,ArrayBuffer(13.0))
(k10.0,ArrayBuffer(1.0))
(k1.0,ArrayBuffer(100.0))
(k10.0,ArrayBuffer(11.0))
(k10.0,ArrayBuffer(11.0))
(k1.0,ArrayBuffer(2.0))
(k4.0,ArrayBuffer(7.0))


val dataArrRed = dataArr.reduceByKey((x,y)=>x++=y)
//dataArrRed :RDD[(String,ArrayBuffer[Double])]
dataArrRed.collect().foreach(println(_))
//output
(k1.0,ArrayBuffer(2.0, 4.0, 100.0, 2.0))
(k7.0,ArrayBuffer(8.0))
(k10.0,ArrayBuffer(11.0, 13.0, 1.0, 11.0, 11.0))
(k4.0,ArrayBuffer(5.0, 7.0, 7.0))

val dataARM = dataArrRed.collect().map(
f=>(f._1,sc.makeRDD(f._2,2)))
val dataARMM = dataARM.map(
f=>(f._1,(f._2.variance(),f._2.max(),f._2.min())))
.foreach(println(_))
sc.stop()

//output
(k1.0,(1777.0,100.0,2.0))
(k7.0,(0.0,8.0,8.0))
(k10.0,(18.24,13.0,1.0))
(k4.0,(0.8888888888888888,7.0,5.0))

//更新,现在我同时计算第二列和第三列并将它们放入一个数组(f(1),f(2)),然后用它转换成RDD和aggregateByKey, '零值'是Array(新的StatCounter(),新的StatCounter()),它有一些问题。

val dataArray2 = dataString.split("\\n")
 .map(_.split("\\s+")).map(_.map(_.toDouble))
 .map(f => ("k" + f(0), Array(f(1),f(2))))
val data2 = sc.parallelize(dataArray2)
val dataStat2 = data2.aggregateByKey(Array(new StatCounter(),new 
StatCounter()))
({
(s,v)=>(
s(0).merge(v(0)),s(1).merge(v(1))
)
},{
(s,t)=>(
s(0).merge(v(0)),s(1).merge(v(1))
)})

这是错的。我可以使用Array(new StatCounter(),new StatCounter())吗?感谢。

2 个答案:

答案 0 :(得分:1)

工作的例子。事实证明它是一个单行,另一行将它映射到OP的格式。

获取数据的方式略有不同(测试更方便但结果相同)

val dataString = """1   2   3 
1   4   5
4   5   6
4   7   8
7   8   9
10  11  12
10  13  14
10  1   2
1   100 100
10  11  2
10  11  2
1   2   5
4   7   6  
""".trim

val dataArray = dataString.split("\\n")
 .map(_.split("\\s+")).map(_.map(_.toDouble))
 .map(f => ("k" + f(0), f(1)))
val data = sc.parallelize(dataArray)

按键构建统计数据

val dataStats = data.aggregateByKey(new StatCounter())
                                    ({(s,v)=>s.merge(v)}, {(s,t)=>s.merge(t)})

或者,稍短但可能过于棘手:

val dataStats = data.aggregateByKey(new StatCounter())(_ merge _, _ merge _)

重新格式化为OP的格式并打印

val result = dataStats.map(f=>(f._1,(f._2.variance,f._2.max,f._2.min)))
.foreach(println(_))

输出,除了一些舍入误差外。

(k1.0,(1776.9999999999998,100.0,2.0))
(k7.0,(0.0,8.0,8.0))
(k10.0,(18.240000000000002,13.0,1.0))
(k4.0,(0.888888888888889,7.0,5.0))

编辑:包含两列的版本

  val dataArray = dataString.split("\\n")
    .map(_.split("\\s+")).map(_.map(_.toDouble))
    .map(f => ("k" + f(0), Array(f(1), f(2))))
  val data = sc.parallelize(dataArray)

  val dataStats = data.aggregateByKey(Array(new StatCounter(), new StatCounter()))({(s, v)=> Array(s(0) merge v(0), s(1) merge v(1))}, {(s, t)=> Array(s(0) merge t(0), s(1) merge t(1))})

  val result = dataStats.map(f => (f._1, (f._2(0).variance, f._2(0).max, f._2(0).min), (f._2(1).variance, f._2(1).max, f._2(1).min)))
    .foreach(println(_))

输出

(k1.0,(1776.9999999999998,100.0,2.0),(1716.6875,100.0,3.0))
(k7.0,(0.0,8.0,8.0),(0.0,9.0,9.0))
(k10.0,(18.240000000000002,13.0,1.0),(29.439999999999998,14.0,2.0))
(k4.0,(0.888888888888889,7.0,5.0),(0.888888888888889,8.0,6.0))

EDIT2:“n” - 列版本

val n = 2

  val dataStats = data.aggregateByKey(List.fill(n)(new StatCounter()))(
      {(s, v)=> (s zip v).map{case (si, vi) => si merge vi}},
      {(s, t)=> (s zip t).map{case (si, ti) => si merge ti}})

  val result = dataStats.map(f => (f._1, f._2.map(x => (x.variance, x.max, x.min))))
    .foreach(println(_))

输出与上面相同,但如果您有更多列,则可以更改n。如果任何行中的数组少于n个元素,它将会中断。

答案 1 :(得分:0)

我只想使用stats对象(类StatCounter)。然后,我会:

  • 解析文件,拆分每一行
  • 创建元组并对RDD进行分区
  • 按键使用聚合并收集作为统计对象的RDD