使用带有条件的PySpark窗口函数添加行

时间:2020-02-04 16:43:40

标签: pyspark pyspark-sql pyspark-dataframes

我需要能够将新行添加到PySpark df will值,该值将基于具有公共ID的其他行的内容。最终将有数百万个ID,每个ID都有很多行。我尝试了下面的方法,该方法有效,但似乎过于复杂。

我从以下格式的df开始(但实际上有更多列):

+-------+----------+-------+
|   id  | variable | value |
+-------+----------+-------+
|     1 | varA     |    30 |
|     1 | varB     |     1 |
|     1 | varC     |    -9 |
+-------+----------+-------+

目前,我正在将此df转换为以下格式:

+-----+------+------+------+
|  id | varA | varB | varC |
+-----+------+------+------+
|   1 |   30 |    1 |   -9 |
+-----+------+------+------+

然后在此df上,我可以使用标准的withColumn和when功能根据其他列中的值添加新列。例如:

df = df.withColumn("varD", when((col("varA") > 16) & (col("varC") != -9)), 2).otherwise(1)

哪个会导致:

+-----+------+------+------+------+
|  id | varA | varB | varC | varD |
+-----+------+------+------+------+
|   1 |   30 |    1 |   -9 |    1 |
+-----+------+------+------+------+

然后我可以将此df旋转回原来的格式,从而导致此问题:

+-------+----------+-------+
|   id  | variable | value |
+-------+----------+-------+
|     1 | varA     |    30 |
|     1 | varB     |     1 |
|     1 | varC     |    -9 |
|     1 | varD     |     1 |
+-------+----------+-------+

这行得通,但似乎可以进行成千上万的行,从而导致昂贵且不必要的操作。感觉它应该是可行的,而无需旋转和取消旋转数据。我需要这样做吗?

我已经阅读了有关Window函数的信息,听起来好像它们可能是获得相同结果的另一种方法,但是说实话,我正在努力地开始使用它们。我可以看到如何将它们用于为每个id生成一个值,例如一个和,或找到一个最大值,但还没有找到一种方法来开始应用导致新行的复杂条件。

将很高兴收到任何有关此问题的帮助。

1 个答案:

答案 0 :(得分:1)

您可以使用pandas_udf在已分组的数据上添加/删除行/列,并在pandas udf中实现处理逻辑。

import pyspark.sql.functions as F

row_schema = StructType(
    [StructField("id", IntegerType(), True),
     StructField("variable", StringType(), True),
     StructField("value", IntegerType(), True)]
)

@F.pandas_udf(row_schema, F.PandasUDFType.GROUPED_MAP)
def addRow(pdf):
    val = 1
    if  (len(pdf.loc[(pdf['variable'] == 'varA') & (pdf['value'] > 16)]) > 0 ) & \
        (len(pdf.loc[(pdf['variable'] == 'varC') & (pdf['value'] != -9)]) > 0):
        val = 2
    return pdf.append(pd.Series([1, 'varD', val], index=['id', 'variable', 'value']), ignore_index=True)

df = spark.createDataFrame([[1, 'varA', 30],
                            [1, 'varB', 1],
                            [1, 'varC', -9]
                            ], schema=['id', 'variable', 'value'])

df.groupBy("id").apply(addRow).show()

恢复

+---+--------+-----+
| id|variable|value|
+---+--------+-----+
|  1|    varA|   30|
|  1|    varB|    1|
|  1|    varC|   -9|
|  1|    varD|    1|
+---+--------+-----+