我的spark DataFrame有3列(id:Int,x_axis:Array [Int],y_axis:Array [Int]),下面有一些示例数据:
希望获取数据框中每行的y_axis列的基本统计信息。输出将类似于:
我试过爆炸然后描述,但无法弄清楚预期的输出。 任何帮助或参考都非常赞赏
答案 0 :(得分:1)
正如你的建议,你可以爆炸Y列,然后使用id上的窗口来计算你感兴趣的所有统计数据。不过,你想要在之后重新聚合你的数据,这样你就会产生巨大的中间结果。
Spark没有很多预定义的数组函数。因此,实现您想要的最简单的方法可能是UDF:
val extractFeatures = udf( (x : Seq[Int]) => {
val mean = x.sum.toDouble/x.size
val variance = x.map(i=> i*i).sum.toDouble/x.size - mean*mean
val std = scala.math.sqrt(variance)
Map("count" -> x.size.toDouble,
"mean" -> mean,
"std" -> std,
"min" -> x.min.toDouble,
"max" -> x.max.toDouble)
})
val df = sc
.parallelize(Seq((1,Seq(1,2,3,4,5)), (2,Seq(1,2,1,4))))
.toDF("id", "y")
.withColumn("described_y", extractFeatures('y))
.show(false)
+---+---------------+---------------------------------------------------------------------------------------------+
|id |y |described_y |
+---+---------------+---------------------------------------------------------------------------------------------+
|1 |[1, 2, 3, 4, 5]|Map(count -> 5.0, mean -> 3.0, min -> 1.0, std -> 1.4142135623730951, max -> 5.0, var -> 2.0)|
|2 |[1, 2, 1, 4] |Map(count -> 4.0, mean -> 2.0, min -> 1.0, std -> 1.224744871391589, max -> 4.0, var -> 1.5) |
+---+---------------+---------------------------------------------------------------------------------------------+
顺便说一下,你计算出的stddev实际上是方差。您需要取平方根来获得标准偏差。