Iterating over PySpark GroupedData

Date: 2018-07-23 05:38:55

Tags: python pyspark apache-spark-sql

Let's assume the original data is as follows:

(Reference: Python - splitting dataframe into multiple dataframes based on column values and naming them with those values)

I would like to get a list of sub-DataFrames based on the values of a column (e.g. Region), for example:

Competitor  Region  ProductA  ProductB
Comp1       A       £10       £15
Comp1       B       £11       £16
Comp1       C       £11       £15
Comp2       A       £9        £16
Comp2       B       £12       £14
Comp2       C       £14       £17
Comp3       A       £11       £16
Comp3       B       £10       £15
Comp3       C       £12       £15

In Python, I can do this:

df_A:

Competitor  Region  ProductA  ProductB
Comp1       A       £10       £15
Comp2       A       £9        £16
Comp3       A       £11       £16
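The pandas code itself is not shown in the question; a minimal sketch of one common way to get such sub-frames (using the toy prices above, without the £ sign) could look like this:

```python
import pandas as pd

# Toy frame mirroring the table above (prices without the currency sign)
df = pd.DataFrame({
    'Competitor': ['Comp1'] * 3 + ['Comp2'] * 3 + ['Comp3'] * 3,
    'Region':     ['A', 'B', 'C'] * 3,
    'ProductA':   [10, 11, 11, 9, 12, 14, 11, 10, 12],
    'ProductB':   [15, 16, 15, 16, 14, 17, 16, 15, 15],
})

# groupby yields (key, sub-frame) pairs; collect them into a dict
sub_frames = {region: grp for region, grp in df.groupby('Region')}

print(sub_frames['A'])
```

`sub_frames['A']` then corresponds to the df_A shown above.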

If df is a PySpark DataFrame, can I do the same kind of iteration?

In PySpark, once I call df.groupBy("Region") I get a GroupedData object. I don't need any aggregation like count, mean, etc. I just need a list of sub-DataFrames, each holding the rows with the same "Region" value. Is that possible?

2 Answers:

Answer 0 (score: 5):

Provided that the list of unique values in the grouping column is small enough to fit into the driver's memory, the following approach should work for you. Hope this helps!

import pyspark.sql.functions as F
import pandas as pd

# Sample data (assumes an active SparkSession named `spark`)
df = pd.DataFrame({'region': ['aa','aa','aa','bb','bb','cc'],
                   'x2': [6,5,4,3,2,1],
                   'x3': [1,2,3,4,5,6]})
df = spark.createDataFrame(df)

# Get unique values in the grouping column
groups = [x[0] for x in df.select("region").distinct().collect()]

# Create a filtered DataFrame for each group in a list comprehension
groups_list = [df.filter(F.col('region')==x) for x in groups]

# show the results
[x.show() for x in groups_list]

Result:

+------+---+---+
|region| x2| x3|
+------+---+---+
|    cc|  1|  6|
+------+---+---+

+------+---+---+
|region| x2| x3|
+------+---+---+
|    bb|  3|  4|
|    bb|  2|  5|
+------+---+---+

+------+---+---+
|region| x2| x3|
+------+---+---+
|    aa|  6|  1|
|    aa|  5|  2|
|    aa|  4|  3|
+------+---+---+

Answer 1 (score: 0):

I also needed the name of each group, so I put it into the list as the first element of each pair.

import pyspark.sql.functions as F  # needed for F.col below

valuesA = [('Pirate',1),('Monkey',2),('Ninja',3),('Spaghetti',4),('Pirate',5)]
TableA = sqlContext.createDataFrame(valuesA,['name','id'])

valuesB = [('Pirate',1),('Rutabaga',2),('Ninja',3),('Darth Vader',4),('Pirate',5)]
TableB = sqlContext.createDataFrame(valuesB,['name','id'])

TableA.show()
TableB.show()

ta = TableA.alias('ta')
tb = TableB.alias('tb')

df = ta.join(tb, (ta.name == tb.name) & (ta.id == tb.id),how='full') # Could also use 'full_outer'
df.show()

# Get unique values in the grouping column
groups = [x[0] for x in df.select("ta.name").distinct().collect()]

# Create a filtered DataFrame for each group in a list comprehension
groups_list = [[x,df.filter(F.col('ta.name')==x)] for x in groups]

# show the results (the None group from the full outer join comes back
# empty, because an equality test against null never matches; use
# isNull() if you need those rows)
for x, dfx in groups_list:
    print(x)
    dfx.show()

None
+----+---+----+---+
|name| id|name| id|
+----+---+----+---+
+----+---+----+---+

Spaghetti
+---------+---+----+----+
|     name| id|name|  id|
+---------+---+----+----+
|Spaghetti|  4|null|null|
+---------+---+----+----+

Ninja
+-----+---+-----+---+
| name| id| name| id|
+-----+---+-----+---+
|Ninja|  3|Ninja|  3|
+-----+---+-----+---+

Pirate
+------+---+------+---+
|  name| id|  name| id|
+------+---+------+---+
|Pirate|  1|Pirate|  1|
|Pirate|  5|Pirate|  5|
+------+---+------+---+

Monkey
+------+---+----+----+
|  name| id|name|  id|
+------+---+----+----+
|Monkey|  2|null|null|
+------+---+----+----+