如何在多列上编写Pyspark UDAF?

时间:2017-09-13 01:47:54

标签: apache-spark pyspark spark-dataframe rdd

我在名为end_stats_df的pyspark数据框中有以下数据:

values     start    end    cat1   cat2
10          1        2     A      B
11          1        2     C      B
12          1        2      D     B
510          1        2     D      C
550          1        2     C      B
500          1        2     A      B
80          1        3     A      B

我想以下列方式聚合它:

  • 我想使用" start"和"结束"列作为聚合键
  • 对于每组行,我需要执行以下操作:
    • 计算该组cat1cat2的唯一值数。例如,对于start = 1和end = 2的组,此数字将为4,因为它有A,B,C,D。此数字将存储为{{1 (在这个例子中n = 4)。
    • 对于n字段,对于每个组,我需要对values进行排序,然后选择每个values值,其中n-1是第一个存储的值以上操作。
    • 在汇总结束时,我并不十分关心上述操作后ncat1中的内容。

上述示例的示例输出是:

cat2

如何使用pyspark数据帧完成任务?我假设我需要使用自定义UDAF,对吧?

1 个答案:

答案 0 :(得分:9)

Pyspark不直接支持UDAF,因此我们必须手动进行聚合。

from pyspark.sql import functions as f

def func(values, cat1, cat2):
    n = len(set(cat1 + cat2))
    return sorted(values)[n - 2]


df = spark.read.load('file:///home/zht/PycharmProjects/test/text_file.txt', format='csv', sep='\t', header=True)
df = df.groupBy(df['start'], df['end']).agg(f.collect_list(df['values']).alias('values'),
                                            f.collect_set(df['cat1']).alias('cat1'),
                                            f.collect_set(df['cat2']).alias('cat2'))
df = df.select(df['start'], df['end'], f.UserDefinedFunction(func, StringType())(df['values'], df['cat1'], df['cat2']))