How do I create a PySpark UDF that uses multiple columns?

Asked: 2018-03-19 11:55:55

Tags: apache-spark pyspark user-defined-functions

I need to write some custom code that uses multiple columns within a group of data.

My custom code sets a flag when a value exceeds a threshold, but suppresses the flag if it falls within a certain time of a previous flag.

Here is some sample code:

df = spark.createDataFrame(
    [
        ("a", 1, 0),
        ("a", 2, 1),
        ("a", 3, 1),
        ("a", 4, 1),
        ("a", 5, 1),
        ("a", 6, 0),
        ("a", 7, 1),
        ("a", 8, 1),
        ("b", 1, 0),
        ("b", 2, 1)
    ],
    ["group_col","order_col", "flag_col"]
)
df.show()
+---------+---------+--------+
|group_col|order_col|flag_col|
+---------+---------+--------+
|        a|        1|       0|
|        a|        2|       1|
|        a|        3|       1|
|        a|        4|       1|
|        a|        5|       1|
|        a|        6|       0|
|        a|        7|       1|
|        a|        8|       1|
|        b|        1|       0|
|        b|        2|       1|
+---------+---------+--------+

from pyspark.sql.functions import udf, col, asc
from pyspark.sql.window import Window
def _suppress(dates=None, alert_flags=None, window=2):
    sup_alert_flag = alert_flag
    last_alert_date = None
    for i, alert_flag in enumerate(alert_flag):
        current_date = dates[i]
        if alert_flag == 1:
            if not last_alert_date:
                sup_alert_flag[i] = 1
                last_alert_date = current_date
            elif (current_date - last_alert_date) > window:
                sup_alert_flag[i] = 1
                last_alert_date = current_date
            else:
                sup_alert_flag[i] = 0
        else:
            alert_flag = 0
    return sup_alert_flag

suppress_udf = udf(_suppress, DoubleType())

df_out = df.withColumn("supressed_flag_col", suppress_udf(dates=col("order_col"), alert_flags=col("flag_col"), window=4).Window.partitionBy(col("group_col")).orderBy(asc("order_col")))

df_out.show()

The above fails, but my expected output is as follows:

+---------+---------+--------+------------------+
|group_col|order_col|flag_col|supressed_flag_col|
+---------+---------+--------+------------------+
|        a|        1|       0|                 0|
|        a|        2|       1|                 1|
|        a|        3|       1|                 0|
|        a|        4|       1|                 0|
|        a|        5|       1|                 0|
|        a|        6|       0|                 0|
|        a|        7|       1|                 1|
|        a|        8|       1|                 0|
|        b|        1|       0|                 0|
|        b|        2|       1|                 1|
+---------+---------+--------+------------------+
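
As a reference for the behaviour I expect, here is a minimal plain-Python sketch of the suppression logic (the helper name is mine, and I assume the window is measured in order_col units within a single group; it is only meant to make the intended behaviour concrete, not to be the Spark implementation):

def _suppress_reference(dates, alert_flags, window=4):
    # Keep an alert only if it is more than `window` order units after the
    # last alert that was kept; otherwise suppress it.
    result = []
    last_alert_date = None
    for current_date, flag in zip(dates, alert_flags):
        if flag == 1 and (last_alert_date is None or current_date - last_alert_date > window):
            result.append(1)
            last_alert_date = current_date
        else:
            result.append(0)
    return result

# For group "a" this returns [0, 1, 0, 0, 0, 0, 1, 0], matching the table above.
_suppress_reference([1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 1, 1, 1, 0, 1, 1], window=4)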

2 Answers:

Answer 0 (score: 1)

Edited the answer after giving it more thought.

The general problem seems to be that the result for the current row depends on the result of the previous row; in effect there is a recurrence relation. I have not found a good way to implement a recursive UDF in Spark. The distributed nature of data in Spark makes this hard to implement, at least in my mind. The following solution should work, but it may not scale to large datasets.

from pyspark.sql import Row
import pyspark.sql.functions as F
import pyspark.sql.types as T

suppress_flag_row = Row("order_col", "flag_col", "res_flag")

def suppress_flag( date_alert_flags, window_size ):
    # date_alert_flags is the collected list of (order_col, flag_col) structs for one group;
    # sort by order_col so the recurrence walks the rows in time order.
    sorted_alerts = sorted( date_alert_flags, key=lambda x: x["order_col"])

    res_flags = []
    last_alert_date = None
    for row in sorted_alerts:
        current_date = row["order_col"]
        aflag = row["flag_col"]
        if aflag == 1 and (not last_alert_date or (current_date - last_alert_date) > window_size):
            res = suppress_flag_row(current_date, aflag, True)
            last_alert_date = current_date
        else:
            res = suppress_flag_row(current_date, aflag, False)

        res_flags.append(res)
    return res_flags

# Schema of the structs collected per group (input to the UDF).
in_fields = [T.StructField("order_col", T.IntegerType(), nullable=True)]
in_fields.append(T.StructField("flag_col", T.IntegerType(), nullable=True))

# Output schema: the input fields plus the computed res_flag.
# Copy the list so that appending does not also modify in_fields.
out_fields = list(in_fields)
out_fields.append(T.StructField("res_flag", T.BooleanType(), nullable=True))
out_schema = T.StructType(out_fields)

# The UDF returns one struct per input row, so its return type is an array of structs.
suppress_udf = F.udf(suppress_flag, T.ArrayType(out_schema))

window_size = 4
# Collect each group's rows into a single array of structs so the UDF sees the whole group at once.
tmp = df.groupBy("group_col").agg(F.collect_list(F.struct(F.col("order_col"), F.col("flag_col"))).alias("date_alert_flags"))
tmp2 = tmp.select(F.col("group_col"), suppress_udf(F.col("date_alert_flags"), F.lit(window_size)).alias("suppress_res"))

# Explode the per-group array back into one row per original record and flatten the struct fields.
expand_fields = [F.col("group_col")] + [F.col("res_expand")[f.name].alias(f.name) for f in out_fields]
final_df = tmp2.select(F.col("group_col"), F.explode(F.col("suppress_res")).alias("res_expand")).select(expand_fields)
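
If you want the result in the same shape as the expected output in the question, one possible (untested) follow-up is to cast the boolean res_flag to an integer and rename it; the column name below just mirrors the question:

# Hypothetical follow-up, not part of the core solution above.
df_out = final_df.withColumn("supressed_flag_col", F.col("res_flag").cast(T.IntegerType())).drop("res_flag")
df_out.orderBy("group_col", "order_col").show()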

Answer 1 (score: 0)

I don't think you need a custom function for this. You can use a window with the rowsBetween option to look at a range of the previous 5 rows. Please check it and let me know if I have missed anything.

>>> from pyspark.sql import functions as F
>>> from pyspark.sql import Window

>>> w = Window.partitionBy('group_col').orderBy('order_col').rowsBetween(-5,-1)
>>> df = df.withColumn('supr_flag_col',F.when(F.sum('flag_col').over(w) == 0,1).otherwise(0))
>>> df.orderBy('group_col','order_col').show()
+---------+---------+--------+-------------+
|group_col|order_col|flag_col|supr_flag_col|
+---------+---------+--------+-------------+
|        a|        1|       0|            0|
|        a|        2|       1|            1|
|        a|        3|       1|            0|
|        b|        1|       0|            0|
|        b|        2|       1|            1|
+---------+---------+--------+-------------+