How to add a column to a PySpark dataframe containing mean values based on a grouping on another column

Asked: 2019-01-11 02:01:36

Tags: dataframe pyspark aggregate mean

This is similar to some other questions, but with a difference.

Suppose we have a PySpark dataframe df like the following:

+----+----+----+
|col1|col2|col3|
+----+----+----+
|   A|   5|   6|
|   A|   5|   8|
|   A|   6|   3|
|   A|   5|   9|
|   B|   9|   6|
|   B|   3|   8|
|   B|   9|   8|
|   C|   3|   4|
|   C|   5|   1|
+----+----+----+

I want to add another column, new_col, containing the mean of col2 grouped by col1. So the result should look like this:

   +----+----+----+-------+
   |col1|col2|col3|new_col|
   +----+----+----+-------+
   |   A|   5|   6|   5.25|
   |   A|   5|   8|   5.25|
   |   A|   6|   3|   5.25|
   |   A|   5|   9|   5.25|
   |   B|   9|   6|      7|
   |   B|   3|   8|      7|
   |   B|   9|   8|      7|
   |   C|   3|   4|      4|
   |   C|   5|   1|      4|
   +----+----+----+-------+

Any help would be greatly appreciated.

2 Answers:

Answer 0 (score: 1):

Step 1: Create the dataframe.

from pyspark.sql import SparkSession
from pyspark.sql.functions import avg, col
from pyspark.sql.window import Window

spark = SparkSession.builder.getOrCreate()
values = [('A',5,6),('A',5,8),('A',6,3),('A',5,9),('B',9,6),('B',3,8),('B',9,8),('C',3,4),('C',5,1)]
df = spark.createDataFrame(values, ['col1','col2','col3'])
df.show()
+----+----+----+
|col1|col2|col3|
+----+----+----+
|   A|   5|   6|
|   A|   5|   8|
|   A|   6|   3|
|   A|   5|   9|
|   B|   9|   6|
|   B|   3|   8|
|   B|   9|   8|
|   C|   3|   4|
|   C|   5|   1|
+----+----+----+

Step 2: Create the new column containing the mean of col2, grouped by col1, using a window function.

# Partition the window by col1 so the average is computed per group
# and attached to every row of that group
w = Window().partitionBy('col1')
df = df.withColumn('new_col', avg(col('col2')).over(w))
df.show()
+----+----+----+-------+
|col1|col2|col3|new_col|
+----+----+----+-------+
|   B|   9|   6|    7.0|
|   B|   3|   8|    7.0|
|   B|   9|   8|    7.0|
|   C|   3|   4|    4.0|
|   C|   5|   1|    4.0|
|   A|   5|   6|   5.25|
|   A|   5|   8|   5.25|
|   A|   6|   3|   5.25|
|   A|   5|   9|   5.25|
+----+----+----+-------+
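Note that the window function computes the group mean without collapsing the rows, so the result is produced in a single step, with no separate aggregation and join as in the answer below.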

Answer 1 (score: 0):

Well, after a lot of attempts I managed to answer this question myself. I am posting the answer here for anyone else with a similar problem. The original data here is a csv file.

from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()
# Reading the file (header=True uses the first row as column names)
df = spark.read.csv("file's name.csv", header=True)
df.show()

Output:

+----+----+----+
|col1|col2|col3|
+----+----+----+
|   A|   5|   6|
|   A|   5|   8|
|   A|   6|   3|
|   A|   5|   9|
|   B|   9|   6|
|   B|   3|   8|
|   B|   9|   8|
|   C|   3|   4|
|   C|   5|   1|
+----+----+----+


from pyspark.sql import functions as func
# Grouping the dataframe based on col1
col1group = df.groupBy('col1')
# Computing the average of col2 within each col1 group
a = col1group.agg(func.avg("col2"))
a.show()

Output:

+----+---------+
|col1|avg(col2)|
+----+---------+
|   A|     5.25|
|   B|      7.0|
|   C|      4.0|
+----+---------+
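As a side note, if we alias the aggregate at this step, the joined column already carries its final name and the renaming step at the end becomes unnecessary. A minimal sketch:

# Alias the aggregate so the joined column is already called new_col
a = col1group.agg(func.avg("col2").alias("new_col"))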

Now we join this table with the initial dataframe to produce the desired dataframe:

# Join the per-group averages back onto the original rows on col1
df = df.join(a, on='col1', how='inner')
df.show()

Output:

   +----+----+----+---------+
   |col1|col2|col3|avg(col2)|
   +----+----+----+---------+
   |   A|   5|   6|     5.25|
   |   A|   5|   8|     5.25|
   |   A|   6|   3|     5.25|
   |   A|   5|   9|     5.25|
   |   B|   9|   6|      7.0|
   |   B|   3|   8|      7.0|
   |   B|   9|   8|      7.0|
   |   C|   3|   4|      4.0|
   |   C|   5|   1|      4.0|
   +----+----+----+---------+

Finally, rename the last column to the name we want:

df = df.withColumnRenamed('avg(col2)', 'new_col')
df.show()

Output:

   +----+----+----+-------+
   |col1|col2|col3|new_col|
   +----+----+----+-------+
   |   A|   5|   6|   5.25|
   |   A|   5|   8|   5.25|
   |   A|   6|   3|   5.25|
   |   A|   5|   9|   5.25|
   |   B|   9|   6|    7.0|
   |   B|   3|   8|    7.0|
   |   B|   9|   8|    7.0|
   |   C|   3|   4|    4.0|
   |   C|   5|   1|    4.0|
   +----+----+----+-------+
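For reference, the whole groupBy-and-join pipeline above can also be written as one chained expression. A minimal sketch, assuming df is the original dataframe as read from the csv:

from pyspark.sql import functions as func

# Compute the per-group mean of col2 and join it back onto the original rows
result = df.join(
    df.groupBy('col1').agg(func.avg('col2').alias('new_col')),
    on='col1',
    how='inner'
)
result.show()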