Transposing rows into columns in Spark SQL (pyspark)

Date: 2017-10-25 07:46:05

Tags: sql pyspark apache-spark-sql

I want to do the following transformation in Spark. My goal is the final output, and I think that if I can produce the intermediate output first, the final output follows easily. Any ideas on how to turn these rows into columns would be very helpful (a minimal pivot sketch for the intermediate step is included after the tables below).

RowID  Name  Place
1      Gaga  India,US,UK
1      Katy  UK,India,Europe
1      Bey   Europe
2      Gaga  Null
2      Katy  India,Europe
2      Bey   US
3      Gaga  Europe
3      Katy  US
3      Bey   Null

Output:

RowID   Id  Gaga    Katy    Bey
1       1   India   UK      Europe
1       2   US      India   Null
1       3   UK      Europe  Null
2       1   Null    India   US
2       2   Null    Europe  Null
3       1   Europe  US      Null


Intermediate Output:

RowID   Gaga         Katy               Bey
1       India,US,UK  UK,India,Europe    Europe
2       Null         India,Europe       US
3       Europe       US                 Null
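
For reference, the intermediate table above can be produced with a single pivot over Name. A minimal sketch, assuming the input has already been loaded into a DataFrame named df with the columns RowID, Name and Place:

from pyspark.sql import functions as F

# pivot Name into columns, keeping the raw comma-separated Place string per RowID
intermediate = (df.groupBy('RowID')
                  .pivot('Name')
                  .agg(F.first('Place'))
                  .select('RowID', 'Gaga', 'Katy', 'Bey')
                  .orderBy('RowID'))
intermediate.show()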

2 Answers:

Answer 0 (score: 3)

I have tried this using DataFrame functions and a UDF. Hope it helps you.

>>> from pyspark.sql import functions as F
>>> from pyspark.sql.types import IntegerType
>>> from functools import reduce
>>> from pyspark.sql import DataFrame
>>> from pyspark.sql import Window
>>> l = [(1,'Gaga','India,US,UK'),(1,'Katy','UK,India,Europe'),(1,'Bey','Europe'),(2,'Gaga',None),(2,'Katy','India,Europe'),(2,'Bey','US'),(3,'Gaga','Europe'),
... (3,'Katy','US'),(3,'Bey',None)]
>>> df = spark.createDataFrame(l,['RowID','Name','Place'])
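>>> # split the comma-separated Place string into an array column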
>>> df = df.withColumn('Placelist',F.split(df.Place,','))
>>> df.show()
+-----+----+---------------+-------------------+
|RowID|Name|          Place|          Placelist|
+-----+----+---------------+-------------------+
|    1|Gaga|    India,US,UK|    [India, US, UK]|
|    1|Katy|UK,India,Europe|[UK, India, Europe]|
|    1| Bey|         Europe|           [Europe]|
|    2|Gaga|           null|               null|
|    2|Katy|   India,Europe|    [India, Europe]|
|    2| Bey|             US|               [US]|
|    3|Gaga|         Europe|           [Europe]|
|    3|Katy|             US|               [US]|
|    3| Bey|           null|               null|
+-----+----+---------------+-------------------+

>>> udf1 = F.udf(lambda x : len(x) if x is not None else 0,IntegerType())
>>> maxlen = df.agg(F.max(udf1('Placelist'))).first()[0]
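>>> # pivot Name into columns, keeping each RowID's list of places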
>>> df1 = df.groupby('RowID').pivot('Name').agg(F.first('Placelist'))
>>> df1.show()
+-----+--------+---------------+-------------------+
|RowID|     Bey|           Gaga|               Katy|
+-----+--------+---------------+-------------------+
|    1|[Europe]|[India, US, UK]|[UK, India, Europe]|
|    3|    null|       [Europe]|               [US]|
|    2|    [US]|           null|    [India, Europe]|
+-----+--------+---------------+-------------------+

>>> finaldf = reduce(
...     DataFrame.unionAll,
...     (df1.select("RowID", F.col("Bey").getItem(i), F.col("Gaga").getItem(i),F.col("Katy").getItem(i) )
...         for i in range(maxlen))
... ).toDF(*df1.columns).na.drop('all',subset=df1.columns[1:]).orderBy('RowID')
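>>> # number the surviving rows within each RowID to build the ID column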
>>> w = Window.partitionBy('RowID').orderBy('Bey')
>>> finaldf = finaldf.withColumn('ID',F.row_number().over(w))
>>> finaldf.select('RowID','ID','Gaga','Katy','Bey').show()
+-----+---+------+------+------+
|RowID| ID|  Gaga|  Katy|   Bey|
+-----+---+------+------+------+
|    1|  1|    US| India|  null|
|    1|  2|    UK|Europe|  null|
|    1|  3| India|    UK|Europe|
|    2|  1|  null|Europe|  null|
|    2|  2|  null| India|    US|
|    3|  1|Europe|    US|  null|
+-----+---+------+------+------+
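
As a design note on the answer above: the length UDF can usually be replaced with the built-in F.size, and the unionAll loop over array indices can be replaced with F.posexplode (available in Spark 2.1+), which emits each array element together with its position, so the position itself can serve as the ID and follows the Id numbering in the question's expected output. A minimal sketch, reusing the df with the Placelist column created above:

from pyspark.sql import functions as F

# posexplode yields (pos, place) pairs for every entry in Placelist; null lists are skipped
exploded = df.select('RowID', 'Name', F.posexplode('Placelist').alias('pos', 'place'))

# pivot Name per (RowID, position); the array position becomes the ID
altdf = (exploded.groupby('RowID', 'pos').pivot('Name').agg(F.first('place'))
         .withColumn('ID', F.col('pos') + 1)
         .select('RowID', 'ID', 'Gaga', 'Katy', 'Bey')
         .orderBy('RowID', 'ID'))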

Answer 1 (score: -1)

An alternative solution without using a UDF:

from pyspark.sql import Row
from pyspark.sql.types import StructField, StructType, StringType, IntegerType
from pyspark.sql.window import Window
from pyspark.sql.functions import create_map, explode, struct, split, row_number, to_json
from functools import reduce

# DataFrame schema

dfSchema = StructType([
    StructField('RowID', IntegerType()),
    StructField('Name', StringType()),
    StructField('Place', StringType())
])

# Raw data

rowID_11 = Row(1, 'Gaga', 'India,US,UK')
rowID_12 = Row(1, 'Katy', 'UK,India,Europe')
rowID_13 = Row(1, 'Bey', 'Europe')
rowID_21 = Row(2, 'Gaga', None)
rowID_22 = Row(2, 'Katy', 'India,Europe')
rowID_23 = Row(2, 'Bey', 'US')
rowID_31 = Row(3, 'Gaga', 'Europe')
rowID_32 = Row(3, 'Katy', 'US')
rowID_33 = Row(3, 'Bey', None)

rowList = [rowID_11, rowID_12, rowID_13, 
           rowID_21, rowID_22, rowID_23, 
           rowID_31, rowID_32, rowID_33]

# Create the initial DataFrame

df = spark.createDataFrame(rowList, dfSchema)
df.show()

+-----+----+---------------+
|RowID|Name|          Place|
+-----+----+---------------+
|    1|Gaga|    India,US,UK|
|    1|Katy|UK,India,Europe|
|    1| Bey|         Europe|
|    2|Gaga|           null|
|    2|Katy|   India,Europe|
|    2| Bey|             US|
|    3|Gaga|         Europe|
|    3|Katy|             US|
|    3| Bey|           null|
+-----+----+---------------+

# Use create_map, struct and to_json to create the intermediate output

jsonDFCol = df.select(
                 to_json(
                 create_map('Name', 
                            struct('RowID', 'Place')))\
                                .alias('name_place'))

jsonList = [js[0] for js in jsonDFCol.rdd.collect()] 
jsonDF = spark.read.json(sc.parallelize(jsonList))
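# Note: jsonDFCol.rdd.collect() above pulls every JSON row to the driver; this assumes the data set is small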

intermediateList = [jsonDF.selectExpr(f'{name}.RowID', f'{name}.Place AS {name}')\
    .where('RowID is not Null') for name in jsonDF.columns]

intermediateDF = reduce(lambda curr, nxt: 
                        curr.join(nxt, on='RowID'), 
                        intermediateList).sort('RowID')\
                        .select('RowID', 'Gaga', 'Katy', 'Bey')

intermediateDF.show()

+-----+-----------+---------------+------+
|RowID|       Gaga|           Katy|   Bey|
+-----+-----------+---------------+------+
|    1|India,US,UK|UK,India,Europe|Europe|
|    2|       null|   India,Europe|    US|
|    3|     Europe|             US|  null|
+-----+-----------+---------------+------+

# Use a window to create the Id column

rowWindow = Window.partitionBy('RowID').orderBy('RowID') 

# Use the split and explode functions to obtain the final output

finalDFList = \
[intermediateDF\
    .select('RowID', 
            explode(split(intermediateDF[col_], ',')).alias(col_))\
            .withColumn('Id', row_number().over(rowWindow)) 
for col_ in intermediateDF.columns[1:]]

finalDFID = reduce(lambda curr, nxt: curr.select('RowID', 'Id')\
    .unionAll(nxt.select('RowID', 'Id')), finalDFList)

finalDF = reduce(lambda curr, nxt: 
                        curr.join(nxt, on=['RowID', 'Id'], how='left'), 
                        finalDFList, finalDFID).distinct()\
                        .sort('RowID', 'Id')\
                        .select('RowID', 'Id', 
                                'Gaga', 'Katy', 'Bey')

finalDF.show()

+-----+---+------+------+------+
|RowID| Id|  Gaga|  Katy|   Bey|
+-----+---+------+------+------+
|    1|  1| India|    UK|Europe|
|    1|  2|    US| India|  null|
|    1|  3|    UK|Europe|  null|
|    2|  1|  null| India|    US|
|    2|  2|  null|Europe|  null|
|    3|  1|Europe|    US|  null|
+-----+---+------+------+------+