在PySpark数据框的组中的列上应用函数

时间:2019-08-22 06:20:28

标签: python pyspark pyarrow

我有一个这样的PySpark数据框,

+----------+--------+---------+
|id_       | p      |   a     |
+----------+--------+---------+
|  1       | 4      |   12    |
|  1       | 3      |   14    |
|  1       | -7     |   16    |
|  1       | 5      |   11    |
|  1       | -20    |   90    |
|  1       | 5      |   120   |
|  2       |  11    |   267   |
|  2       | -98    |   124   |
|  2       | -87    |   120   |
|  2       | -1     |   44    |
|  2       |  5     |   1     |
|  2       |  7     |   23    |
-------------------------------

我也有这样的python函数,

def fun(x):
    total = 0
    result = np.empty_like(x)
    for i, y in enumerate(x):
        total += (y)
        if total < 0:
            total = 0
        result[i] = total

    return result

我想对id_列上的PySpark数据帧进行分组,并在fun列上应用函数p

我想要类似

spark_df.groupBy('id_')['p'].apply(fun)

我目前正在pyarrow的帮助下使用pandas udf进行此操作,这对于我的申请时间而言效率不高。

我正在寻找的结果是

[4, 7, 0, 5, 0, 5, 11, -98, -87, -1, 5, 7]

这是我正在寻找的结果数据框,

+----------+--------+---------+
|id_       | p      |   a     |
+----------+--------+---------+
|  1       | 4      |   12    |
|  1       | 7      |   14    |
|  1       | 0      |   16    |
|  1       | 5      |   11    |
|  1       | 0      |   90    |
|  1       | 5      |   120   |
|  2       |  11    |   267   |
|  2       | 0      |   124   |
|  2       | 0      |   120   |
|  2       | 0      |   44    |
|  2       |  5     |   1     |
|  2       |  12    |   23    |
-------------------------------

是否可以使用pyspark API本身直接执行此操作?

我可以使用pcollect_list上分组时使用id_udf汇总并列到列表中,并在其上使用explode并使用p获取列mysqldump根据需要在结果数据框中显示。

但是如何保留数据框中的其他列??

2 个答案:

答案 0 :(得分:1)

是的,您可以将上述python函数转换为Pyspark UDF。 由于您要返回整数数组,因此将返回类型指定为ArrayType(IntegerType())很重要。

下面是代码,

from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, IntegerType, collect_list

@udf(returnType=ArrayType(IntegerType()))
def fun(x):
    total = 0
    result = np.empty_like(x)
    for i, y in enumerate(x):
        total += (y)
        if total < 0:
            total = 0
        result[i] = total
    return result.tolist()    # Convert NumPy Array to Python List

由于udf的输入必须是列表,因此让我们基于'id'对数据进行分组,并将行转换为数组。

df = df.groupBy('id_').agg(collect_list('p'))
df = df.toDF('id_', 'p_')    # Assign a new alias name 'p_'
df.show(truncate=False)

输入数据:

+---+------------------------+
|id_|collect_list(p)         |
+---+------------------------+
|1  |[4, 3, -7, 5, -20, 5]   |
|2  |[11, -98, -87, -1, 5, 7]|
+---+------------------------+

接下来,我们将udf应用于此数据,

df.select('id_', fun(df.p_)).show(truncate=False)

输出:

+---+--------------------+
|id_|fun(p_)             |
+---+--------------------+
|1  |[4, 7, 0, 5, 0, 5]  |
|2  |[11, 0, 0, 0, 5, 12]|
+---+--------------------+

答案 1 :(得分:0)

通过以下步骤,我设法达到了所需的结果,

我的DataFrame看起来像这样,

+---+---+---+
|id_|  p|  a|
+---+---+---+
|  1|  4| 12|
|  1|  3| 14|
|  1| -7| 16|
|  1|  5| 11|
|  1|-20| 90|
|  1|  5|120|
|  2| 11|267|
|  2|-98|124|
|  2|-87|120|
|  2| -1| 44|
|  2|  5|  1|
|  2|  7| 23|
+---+---+---+

我将对id_上的数据框进行分组,并使用collect_list收集我想将该功能应用于列表的列,并像这样应用该功能,

agg_df = df.groupBy('id_').agg(F.collect_list('p').alias('collected_p'))
agg_df = agg_df.withColumn('new', fun('collected_p'))

我现在想以某种方式将agg_df合并到我的原始数据框中。为此,我将首先使用explode获取行中的列new中的值。

agg_df = agg_df.withColumn('exploded', F.explode('new'))

为了合并,我将使用monotonically_increasing_id为原始数据帧和id生成agg_df。这样,我将为每个数据帧制作idx,因为两个数据帧的monotonically_increasing_id不会相同。

agg_df = agg_df.withColumn('id_mono', F.monotonically_increasing_id())
df = df.withColumn('id_mono', F.monotonically_increasing_id())

w = Window().partitionBy(F.lit(0)).orderBy('id_mono')

df = df.withColumn('idx', F.row_number().over(w))
agg_df = agg_df.withColumn('idx', F.row_number().over(w))

df = df.join(agg_df.select('idx', 'exploded'), ['idx']).drop('id_mono', 'idx')


+---+---+---+--------+
|id_|  p|  a|exploded|
+---+---+---+--------+
|  1|  4| 12|       4|
|  1|  3| 14|       7|
|  1| -7| 16|       0|
|  1|  5| 11|       5|
|  1|-20| 90|       0|
|  1|  5|120|       5|
|  2| 11|267|      11|
|  2|-98|124|       0|
|  2|-87|120|       0|
|  2| -1| 44|       0|
|  2|  5|  1|       5|
|  2|  7| 23|      12|
+---+---+---+--------+

我不确定这是否是直接的方法。如果有人可以为此建议任何优化,那就太好了。