Converting a Spark dataframe back to long format after collect_list(column)

Date: 2017-11-01 17:17:29

Tags: python apache-spark pyspark pyspark-sql

Suppose we have the iris dataframe:

import pandas as pd

# read iris with pandas, then convert it to a Spark dataframe
df = pd.read_csv('https://raw.githubusercontent.com/uiuc-cse/data-fa14/gh-pages/data/iris.csv')
df = spark.createDataFrame(df)

I need to run some aggregation on sepal width by species, for example to get the 3 highest values in each group.

import pyspark.sql.functions as F
get_max_3 = F.udf(
    lambda x: sorted(x)[-3:]
)

agged = df.groupBy('species').agg(F.collect_list('sepal_width').alias('sepal_width'))
agged = agged.withColumn('sepal_width', get_max_3('sepal_width'))

+----------+---------------+
|   species|    sepal_width|
+----------+---------------+
| virginica|[3.6, 3.8, 3.8]|
|versicolor|[3.2, 3.3, 3.4]|
|    setosa|[4.1, 4.2, 4.4]|
+----------+---------------+

Now, how can I efficiently transform this into a dataframe in long format, meaning three rows per species, one row per value?

Is there a way to do this without using collect_list?

1 Answer:

Answer 0: (score: 3)

To convert the dataframe to long format you can use explode; to do that, you first need to fix the udf so that it returns the correct type:

from pyspark.sql.types import *
import pyspark.sql.functions as F

# declare the return type as an array of doubles so the result can be exploded
get_max_3 = F.udf(lambda x: sorted(x)[-3:], ArrayType(DoubleType()))

# rebuild the aggregation with the fixed udf
agged = df.groupBy('species').agg(F.collect_list('sepal_width').alias('sepal_width'))
agged = agged.withColumn('sepal_width', get_max_3('sepal_width'))
agged.withColumn('sepal_width', F.explode(F.col('sepal_width'))).show()

+----------+-----------+
|   species|sepal_width|
+----------+-----------+
| virginica|        3.6|
| virginica|        3.8|
| virginica|        3.8|
|versicolor|        3.2|
|versicolor|        3.3|
|versicolor|        3.4|
|    setosa|        4.1|
|    setosa|        4.2|
|    setosa|        4.4|
+----------+-----------+
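
As a side note: on newer Spark releases (2.4+, which is an assumption beyond the original post), the Python udf can be avoided entirely by sorting and slicing the collected array with built-in functions. A minimal sketch:

import pyspark.sql.functions as F

agged = df.groupBy('species').agg(F.collect_list('sepal_width').alias('sepal_width'))
# sort_array(..., asc=False) orders each array descending; slice(..., 1, 3) keeps the top 3 values
agged = agged.withColumn('sepal_width', F.slice(F.sort_array('sepal_width', asc=False), 1, 3))
agged.withColumn('sepal_width', F.explode('sepal_width')).show()

Keeping the work in built-in functions avoids shipping every group to a Python worker, which is usually cheaper than a udf.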

Alternatively, to do it without collect_list and explode, you can rank the sepal_width values within each species first and then filter on the rank:

df.selectExpr(
    "species", "sepal_width", 
    "row_number() over (partition by species order by sepal_width desc) as rn"
).where(F.col("rn") <= 3).drop("rn").show()

+----------+-----------+
|   species|sepal_width|
+----------+-----------+
| virginica|        3.8|
| virginica|        3.8|
| virginica|        3.6|
|versicolor|        3.4|
|versicolor|        3.3|
|versicolor|        3.2|
|    setosa|        4.4|
|    setosa|        4.2|
|    setosa|        4.1|
+----------+-----------+
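
The same ranking can also be written with the DataFrame window API instead of selectExpr; a minimal equivalent sketch (same window/filter logic, only the syntax differs):

import pyspark.sql.functions as F
from pyspark.sql.window import Window

# rank the rows within each species by sepal_width, descending, and keep the top 3
w = Window.partitionBy('species').orderBy(F.col('sepal_width').desc())
(df.withColumn('rn', F.row_number().over(w))
   .where(F.col('rn') <= 3)
   .select('species', 'sepal_width')
   .show())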