Spark数据帧聚合在多个列上

时间:2016-03-24 02:31:36

标签: pyspark apache-spark-sql spark-dataframe

Actually I am working on pyspark code. My dataframe is

+-------+--------+--------+--------+--------+
|element|collect1|collect2|collect3|collect4|
+-------+--------+--------+--------+--------+
|A1     |   1.02 |    2.6 |   5.21 |    3.6 |
|A2     |   1.61 |   2.42 |   4.88 |   6.08 |
|B1     |   1.66 |   2.01 |    5.0 |    4.3 |
|C2     |   2.01 |   1.85 |   3.42 |   4.44 |
+-------+--------+--------+--------+--------+

我需要通过聚合所有collectX列来找到每个元素的mean和stddev。最终结果如下:

+-------+--------+--------+
|element|mean    |stddev  |
+-------+--------+--------+
|A1     |   3.11 |   1.76 |
|A2     |   3.75 |   2.09 |
|B1     |   3.24 |   1.66 |
|C2     |   2.93 |   1.23 |
+-------+--------+--------+

以下代码细分了各列的所有平均值 df.groupBy("元件&#34)。平均()表示()。而不是为每一列做,是否可以汇总所有列?

+-------+-------------+-------------+-------------+-------------+
|element|avg(collect1)|avg(collect2)|avg(collect3)|avg(collect4)|
+-------+-------------+-------------+-------------+-------------+
|A1     |   1.02      |   2.6       |   5.21      |    3.6      |
|A2     |   1.61      |   2.42      |   4.88      |   6.08      |
|B1     |   1.66      |   2.01      |    5.0      |    4.3      |
|C2     |   2.01      |   1.85      |   3.42      |   4.44      |
+-------+-------------+-------------+-------------+-------------+

我尝试使用describe函数,因为它具有完整的聚合函数,但仍然显示为单独的列 df.groupBy("元件&#34)。平均()描述()显示()

感谢

2 个答案:

答案 0 :(得分:1)

Spark允许您收集每列的所有类型的统计信息。您正在尝试计算每行的统计数据。在这种情况下,您可以使用udf来破解某些内容。这是一个例子:D

$ pyspark
>>> from pyspark.sql.types import DoubleType
>>> from pyspark.sql.functions import array, udf
>>>
>>> mean = udf(lambda v: sum(v) / len(v), DoubleType())
>>> df = sc.parallelize([['A1', 1.02, 2.6, 5.21, 3.6], ['A2', 1.61, 2.42, 4.88, 6.08]]).toDF(['element', 'collect1', 'collect2', 'collect3', 'collect4'])
>>> df.show()
+-------+--------+--------+--------+--------+
|element|collect1|collect2|collect3|collect4|
+-------+--------+--------+--------+--------+
|     A1|    1.02|     2.6|    5.21|     3.6|
|     A2|    1.61|    2.42|    4.88|    6.08|
+-------+--------+--------+--------+--------+
>>> df.select('element', mean(array(df.columns[1:])).alias('mean')).show()
+-------+------+
|element|  mean|
+-------+------+
|     A1|3.1075|
|     A2|3.7475|
+-------+------+

答案 1 :(得分:0)

您是否尝试过将这些列添加到一起并可能除以4?

SELECT avg((collect1 + collect2 + collect3 + collect4) / 4),
  stddev((collect1 + collect2 + collect3 + collect4) / 4)

这不会完全按照你的意愿去做,但要理解。

不确定您的语言,但如果您对硬编码感到满意,则可以随时快速构建查询:

val collectColumns = df.columns.filter(_.startsWith("collect"))
val stmnt = "SELECT avg((" + collectColumns.mkString(" + ") + ") / " + collectColumns.length + "))"

你明白了。