如何在Spark SQL中的多个列上进行数据透视?

时间:2017-07-11 13:27:52

标签: python apache-spark pyspark pyspark-sql

我需要在pyspark数据框中调换多个列。示例数据框,

 >>> d = [(100,1,23,10),(100,2,45,11),(100,3,67,12),(100,4,78,13),(101,1,23,10),(101,2,45,13),(101,3,67,14),(101,4,78,15),(102,1,23,10),(102,2,45,11),(102,3,67,16),(102,4,78,18)]
>>> mydf = spark.createDataFrame(d,['id','day','price','units'])
>>> mydf.show()
+---+---+-----+-----+
| id|day|price|units|
+---+---+-----+-----+
|100|  1|   23|   10|
|100|  2|   45|   11|
|100|  3|   67|   12|
|100|  4|   78|   13|
|101|  1|   23|   10|
|101|  2|   45|   13|
|101|  3|   67|   14|
|101|  4|   78|   15|
|102|  1|   23|   10|
|102|  2|   45|   11|
|102|  3|   67|   16|
|102|  4|   78|   18|
+---+---+-----+-----+

现在,如果我需要根据日期为每个id获取价格列,那么我可以使用pivot方法,

>>> pvtdf = mydf.withColumn('combcol',F.concat(F.lit('price_'),mydf['day'])).groupby('id').pivot('combcol').agg(F.first('price'))
>>> pvtdf.show()
+---+-------+-------+-------+-------+
| id|price_1|price_2|price_3|price_4|
+---+-------+-------+-------+-------+
|100|     23|     45|     67|     78|
|101|     23|     45|     67|     78|
|102|     23|     45|     67|     78|
+---+-------+-------+-------+-------+

因此,当我需要将单位列作为价格进行转置时,要么我需要为单位创建一个以上的数据帧,然后使用id.But加入两者,当我有更多列时,我尝试了一个函数这样做,

>>> def pivot_udf(df,*cols):
...     mydf = df.select('id').drop_duplicates()
...     for c in cols:
...        mydf = mydf.join(df.withColumn('combcol',F.concat(F.lit('{}_'.format(c)),df['day'])).groupby('id').pivot('combcol').agg(F.first(c)),'id')
...     return mydf
...
>>> pivot_udf(mydf,'price','units').show()
+---+-------+-------+-------+-------+-------+-------+-------+-------+
| id|price_1|price_2|price_3|price_4|units_1|units_2|units_3|units_4|
+---+-------+-------+-------+-------+-------+-------+-------+-------+
|100|     23|     45|     67|     78|     10|     11|     12|     13|
|101|     23|     45|     67|     78|     10|     13|     14|     15|
|102|     23|     45|     67|     78|     10|     11|     16|     18|
+---+-------+-------+-------+-------+-------+-------+-------+-------+

需要建议,如果这是一个好的做法,如果有任何其他更好的方法。提前谢谢!

3 个答案:

答案 0 :(得分:2)

问题中的解决方案是我能得到的最好的解决方案。唯一的改进是cache输入数据集,以避免双重扫描,即

mydf.cache
pivot_udf(mydf,'price','units').show()

答案 1 :(得分:1)

与spark 1.6版本一样,我认为这是唯一的方法,因为pivot只需要一列,并且有第二个属性值,您可以在其上传递该列的不同值,这将使您的代码运行得更快,否则火花必须为你运行,所以是的,这是正确的方法。

答案 2 :(得分:1)

这是涉及单个枢轴的非UDF方法(因此,只需进行一次列扫描即可识别所有唯一日期)。

mydf.groupBy('id').pivot('day').agg(F.first('price').alias('price'),F.first('units').alias('unit'))

这是结果(对不匹配的排序和命名表示歉意):

+---+-------+------+-------+------+-------+------+-------+------+               
| id|1_price|1_unit|2_price|2_unit|3_price|3_unit|4_price|4_unit|
+---+-------+------+-------+------+-------+------+-------+------+
|100|     23|    10|     45|    11|     67|    12|     78|    13|
|101|     23|    10|     45|    13|     67|    14|     78|    15|
|102|     23|    10|     45|    11|     67|    16|     78|    18|
+---+-------+------+-------+------+-------+------+-------+------+

我们仅在进行日期调整后在priceunit列上进行汇总。