我有一个这样的PySpark数据框,
+----------+--------+---------+
|id_ | p | a |
+----------+--------+---------+
| 1 | 4 | 12 |
| 1 | 3 | 14 |
| 1 | -7 | 16 |
| 1 | 5 | 11 |
| 1 | -20 | 90 |
| 1 | 5 | 120 |
| 2 | 11 | 267 |
| 2 | -98 | 124 |
| 2 | -87 | 120 |
| 2 | -1 | 44 |
| 2 | 5 | 1 |
| 2 | 7 | 23 |
-------------------------------
我也有这样的python函数,
def fun(x):
total = 0
result = np.empty_like(x)
for i, y in enumerate(x):
total += (y)
if total < 0:
total = 0
result[i] = total
return result
我想对id_
列上的PySpark数据帧进行分组,并在fun
列上应用函数p
。
我想要类似
spark_df.groupBy('id_')['p'].apply(fun)
我目前正在pyarrow
的帮助下使用pandas udf进行此操作,这对于我的申请时间而言效率不高。
我正在寻找的结果是
[4, 7, 0, 5, 0, 5, 11, -98, -87, -1, 5, 7]
这是我正在寻找的结果数据框,
+----------+--------+---------+
|id_ | p | a |
+----------+--------+---------+
| 1 | 4 | 12 |
| 1 | 7 | 14 |
| 1 | 0 | 16 |
| 1 | 5 | 11 |
| 1 | 0 | 90 |
| 1 | 5 | 120 |
| 2 | 11 | 267 |
| 2 | 0 | 124 |
| 2 | 0 | 120 |
| 2 | 0 | 44 |
| 2 | 5 | 1 |
| 2 | 12 | 23 |
-------------------------------
是否可以使用pyspark API本身直接执行此操作?
我可以使用p
在collect_list
上分组时使用id_
将udf
汇总并列到列表中,并在其上使用explode
并使用p
获取列mysqldump
根据需要在结果数据框中显示。
但是如何保留数据框中的其他列??
答案 0 :(得分:1)
是的,您可以将上述python函数转换为Pyspark UDF。
由于您要返回整数数组,因此将返回类型指定为ArrayType(IntegerType())
很重要。
下面是代码,
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, IntegerType, collect_list
@udf(returnType=ArrayType(IntegerType()))
def fun(x):
total = 0
result = np.empty_like(x)
for i, y in enumerate(x):
total += (y)
if total < 0:
total = 0
result[i] = total
return result.tolist() # Convert NumPy Array to Python List
由于udf
的输入必须是列表,因此让我们基于'id'对数据进行分组,并将行转换为数组。
df = df.groupBy('id_').agg(collect_list('p'))
df = df.toDF('id_', 'p_') # Assign a new alias name 'p_'
df.show(truncate=False)
输入数据:
+---+------------------------+
|id_|collect_list(p) |
+---+------------------------+
|1 |[4, 3, -7, 5, -20, 5] |
|2 |[11, -98, -87, -1, 5, 7]|
+---+------------------------+
接下来,我们将udf
应用于此数据,
df.select('id_', fun(df.p_)).show(truncate=False)
输出:
+---+--------------------+
|id_|fun(p_) |
+---+--------------------+
|1 |[4, 7, 0, 5, 0, 5] |
|2 |[11, 0, 0, 0, 5, 12]|
+---+--------------------+
答案 1 :(得分:0)
通过以下步骤,我设法达到了所需的结果,
我的DataFrame看起来像这样,
+---+---+---+
|id_| p| a|
+---+---+---+
| 1| 4| 12|
| 1| 3| 14|
| 1| -7| 16|
| 1| 5| 11|
| 1|-20| 90|
| 1| 5|120|
| 2| 11|267|
| 2|-98|124|
| 2|-87|120|
| 2| -1| 44|
| 2| 5| 1|
| 2| 7| 23|
+---+---+---+
我将对id_
上的数据框进行分组,并使用collect_list
收集我想将该功能应用于列表的列,并像这样应用该功能,
agg_df = df.groupBy('id_').agg(F.collect_list('p').alias('collected_p'))
agg_df = agg_df.withColumn('new', fun('collected_p'))
我现在想以某种方式将agg_df
合并到我的原始数据框中。为此,我将首先使用explode获取行中的列new
中的值。
agg_df = agg_df.withColumn('exploded', F.explode('new'))
为了合并,我将使用monotonically_increasing_id
为原始数据帧和id
生成agg_df
。这样,我将为每个数据帧制作idx
,因为两个数据帧的monotonically_increasing_id
不会相同。
agg_df = agg_df.withColumn('id_mono', F.monotonically_increasing_id())
df = df.withColumn('id_mono', F.monotonically_increasing_id())
w = Window().partitionBy(F.lit(0)).orderBy('id_mono')
df = df.withColumn('idx', F.row_number().over(w))
agg_df = agg_df.withColumn('idx', F.row_number().over(w))
df = df.join(agg_df.select('idx', 'exploded'), ['idx']).drop('id_mono', 'idx')
+---+---+---+--------+
|id_| p| a|exploded|
+---+---+---+--------+
| 1| 4| 12| 4|
| 1| 3| 14| 7|
| 1| -7| 16| 0|
| 1| 5| 11| 5|
| 1|-20| 90| 0|
| 1| 5|120| 5|
| 2| 11|267| 11|
| 2|-98|124| 0|
| 2|-87|120| 0|
| 2| -1| 44| 0|
| 2| 5| 1| 5|
| 2| 7| 23| 12|
+---+---+---+--------+
我不确定这是否是直接的方法。如果有人可以为此建议任何优化,那就太好了。